From d65266a868985950af5c347955c7cb028d58b756 Mon Sep 17 00:00:00 2001 From: Sergey Klevtsov <141879860+sklevtsov-nvidia@users.noreply.github.com> Date: Tue, 22 Oct 2024 15:13:36 -0700 Subject: [PATCH] Add all supported GMMA shapes (#1890) --- include/cute/arch/mma_sm90.hpp | 3678 +- include/cute/arch/mma_sm90_gmma.hpp | 47002 +++--------- include/cute/arch/mma_sm90_gmma_ext.hpp | 56445 +++++++++++++++ include/cute/arch/mma_sm90_gmma_sparse.hpp | 49616 +++---------- .../cute/arch/mma_sm90_gmma_sparse_ext.hpp | 60445 ++++++++++++++++ include/cute/atom/mma_traits_sm90_gmma.hpp | 16153 +---- .../cute/atom/mma_traits_sm90_gmma_ext.hpp | 20116 +++++ .../cute/atom/mma_traits_sm90_gmma_sparse.hpp | 13407 +--- .../atom/mma_traits_sm90_gmma_sparse_ext.hpp | 17335 +++++ include/cute/numeric/integral_constant.hpp | 14 + 10 files changed, 180598 insertions(+), 103613 deletions(-) create mode 100644 include/cute/arch/mma_sm90_gmma_ext.hpp create mode 100644 include/cute/arch/mma_sm90_gmma_sparse_ext.hpp create mode 100644 include/cute/atom/mma_traits_sm90_gmma_ext.hpp create mode 100644 include/cute/atom/mma_traits_sm90_gmma_sparse_ext.hpp diff --git a/include/cute/arch/mma_sm90.hpp b/include/cute/arch/mma_sm90.hpp index cd96b2d53..51d34563c 100644 --- a/include/cute/arch/mma_sm90.hpp +++ b/include/cute/arch/mma_sm90.hpp @@ -380,66 +380,141 @@ ss_op_selector() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::MMA_64x256x16_F16F16F16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x16_F16F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::MMA_64x240x16_F16F16F16_SS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x16_F16F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 
224 == 0) { return SM90::GMMA::MMA_64x224x16_F16F16F16_SS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x16_F16F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::MMA_64x208x16_F16F16F16_SS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x16_F16F16F16_SS{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::MMA_64x192x16_F16F16F16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x16_F16F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::MMA_64x176x16_F16F16F16_SS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x16_F16F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::MMA_64x160x16_F16F16F16_SS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x16_F16F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::MMA_64x144x16_F16F16F16_SS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x16_F16F16F16_SS{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::MMA_64x128x16_F16F16F16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x16_F16F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) 
else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::MMA_64x112x16_F16F16F16_SS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x16_F16F16F16_SS{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::MMA_64x96x16_F16F16F16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x16_F16F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::MMA_64x80x16_F16F16F16_SS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x16_F16F16F16_SS{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::MMA_64x64x16_F16F16F16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x16_F16F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::MMA_64x48x16_F16F16F16_SS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x16_F16F16F16_SS{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::MMA_64x32x16_F16F16F16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x16_F16F16F16_SS{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::MMA_64x16x16_F16F16F16_SS{}; } @@ -460,66 +535,141 @@ ss_op_selector() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::MMA_64x256x32_F16E4M3E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F16E4M3E4M3_SS_TN{}; + } +#endif #if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::MMA_64x240x32_F16E4M3E4M3_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F16E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::MMA_64x224x32_F16E4M3E4M3_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F16E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::MMA_64x208x32_F16E4M3E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F16E4M3E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::MMA_64x192x32_F16E4M3E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F16E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::MMA_64x176x32_F16E4M3E4M3_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F16E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::MMA_64x160x32_F16E4M3E4M3_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F16E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::MMA_64x144x32_F16E4M3E4M3_SS_TN{}; } +#endif +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F16E4M3E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::MMA_64x128x32_F16E4M3E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F16E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::MMA_64x112x32_F16E4M3E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F16E4M3E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::MMA_64x96x32_F16E4M3E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F16E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::MMA_64x80x32_F16E4M3E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F16E4M3E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::MMA_64x64x32_F16E4M3E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F16E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::MMA_64x48x32_F16E4M3E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F16E4M3E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::MMA_64x32x32_F16E4M3E4M3_SS_TN{}; } +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F16E4M3E4M3_SS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::MMA_64x16x32_F16E4M3E4M3_SS_TN{}; } @@ -540,66 +690,141 @@ ss_op_selector() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::MMA_64x256x32_F16E4M3E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F16E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::MMA_64x240x32_F16E4M3E5M2_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F16E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::MMA_64x224x32_F16E4M3E5M2_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F16E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::MMA_64x208x32_F16E4M3E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F16E4M3E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::MMA_64x192x32_F16E4M3E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F16E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::MMA_64x176x32_F16E4M3E5M2_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + 
return SM90::GMMA::MMA_64x168x32_F16E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::MMA_64x160x32_F16E4M3E5M2_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F16E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::MMA_64x144x32_F16E4M3E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F16E4M3E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::MMA_64x128x32_F16E4M3E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F16E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::MMA_64x112x32_F16E4M3E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F16E4M3E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::MMA_64x96x32_F16E4M3E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F16E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::MMA_64x80x32_F16E4M3E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F16E4M3E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::MMA_64x64x32_F16E4M3E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if 
constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F16E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::MMA_64x48x32_F16E4M3E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F16E4M3E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::MMA_64x32x32_F16E4M3E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F16E4M3E5M2_SS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::MMA_64x16x32_F16E4M3E5M2_SS_TN{}; } @@ -620,66 +845,141 @@ ss_op_selector() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::MMA_64x256x32_F16E5M2E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F16E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::MMA_64x240x32_F16E5M2E4M3_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F16E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::MMA_64x224x32_F16E5M2E4M3_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F16E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::MMA_64x208x32_F16E5M2E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F16E5M2E4M3_SS_TN{}; + } #endif 
else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::MMA_64x192x32_F16E5M2E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F16E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::MMA_64x176x32_F16E5M2E4M3_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F16E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::MMA_64x160x32_F16E5M2E4M3_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F16E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::MMA_64x144x32_F16E5M2E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F16E5M2E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::MMA_64x128x32_F16E5M2E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F16E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::MMA_64x112x32_F16E5M2E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F16E5M2E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::MMA_64x96x32_F16E5M2E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return 
SM90::GMMA::MMA_64x88x32_F16E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::MMA_64x80x32_F16E5M2E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F16E5M2E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::MMA_64x64x32_F16E5M2E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F16E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::MMA_64x48x32_F16E5M2E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F16E5M2E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::MMA_64x32x32_F16E5M2E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F16E5M2E4M3_SS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::MMA_64x16x32_F16E5M2E4M3_SS_TN{}; } @@ -700,66 +1000,141 @@ ss_op_selector() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::MMA_64x256x32_F16E5M2E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F16E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::MMA_64x240x32_F16E5M2E5M2_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F16E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return 
SM90::GMMA::MMA_64x224x32_F16E5M2E5M2_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F16E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::MMA_64x208x32_F16E5M2E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F16E5M2E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::MMA_64x192x32_F16E5M2E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F16E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::MMA_64x176x32_F16E5M2E5M2_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F16E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::MMA_64x160x32_F16E5M2E5M2_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F16E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::MMA_64x144x32_F16E5M2E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F16E5M2E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::MMA_64x128x32_F16E5M2E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F16E5M2E5M2_SS_TN{}; + } +#endif 
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::MMA_64x112x32_F16E5M2E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F16E5M2E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::MMA_64x96x32_F16E5M2E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F16E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::MMA_64x80x32_F16E5M2E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F16E5M2E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::MMA_64x64x32_F16E5M2E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F16E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::MMA_64x48x32_F16E5M2E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F16E5M2E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::MMA_64x32x32_F16E5M2E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F16E5M2E5M2_SS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::MMA_64x16x32_F16E5M2E5M2_SS_TN{}; } @@ -786,58 +1161,123 @@ ss_op_selector() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::MMA_64x256x16_F32F16F16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr 
(Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x16_F32F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::MMA_64x240x16_F32F16F16_SS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x16_F32F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::MMA_64x224x16_F32F16F16_SS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x16_F32F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::MMA_64x208x16_F32F16F16_SS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x16_F32F16F16_SS{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::MMA_64x192x16_F32F16F16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x16_F32F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::MMA_64x176x16_F32F16F16_SS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x16_F32F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::MMA_64x160x16_F32F16F16_SS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x16_F32F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return 
SM90::GMMA::MMA_64x144x16_F32F16F16_SS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x16_F32F16F16_SS{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::MMA_64x128x16_F32F16F16_SS{}; } #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - else if constexpr (Tile_N % 112 == 0) { - return SM90::GMMA::MMA_64x112x16_F32F16F16_SS{}; + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x16_F32F16F16_SS{}; } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::MMA_64x96x16_F32F16F16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x16_F32F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::MMA_64x80x16_F32F16F16_SS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x16_F32F16F16_SS{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::MMA_64x64x16_F32F16F16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x16_F32F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::MMA_64x48x16_F32F16F16_SS{}; @@ -851,6 +1291,11 @@ ss_op_selector() else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::MMA_64x32x16_F32F16F16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 
== 0) { + return SM90::GMMA::MMA_64x24x16_F32F16F16_SS{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::MMA_64x16x16_F32F16F16_SS{}; } @@ -869,58 +1314,123 @@ ss_op_selector() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::MMA_64x256x16_F32BF16BF16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x16_F32BF16BF16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::MMA_64x240x16_F32BF16BF16_SS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x16_F32BF16BF16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::MMA_64x224x16_F32BF16BF16_SS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x16_F32BF16BF16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::MMA_64x208x16_F32BF16BF16_SS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x16_F32BF16BF16_SS{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::MMA_64x192x16_F32BF16BF16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x16_F32BF16BF16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::MMA_64x176x16_F32BF16BF16_SS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x16_F32BF16BF16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if 
constexpr (Tile_N % 160 == 0) { return SM90::GMMA::MMA_64x160x16_F32BF16BF16_SS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x16_F32BF16BF16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::MMA_64x144x16_F32BF16BF16_SS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x16_F32BF16BF16_SS{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::MMA_64x128x16_F32BF16BF16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x16_F32BF16BF16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::MMA_64x112x16_F32BF16BF16_SS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x16_F32BF16BF16_SS{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::MMA_64x96x16_F32BF16BF16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x16_F32BF16BF16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::MMA_64x80x16_F32BF16BF16_SS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x16_F32BF16BF16_SS{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::MMA_64x64x16_F32BF16BF16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x16_F32BF16BF16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr 
(Tile_N % 48 == 0) { return SM90::GMMA::MMA_64x48x16_F32BF16BF16_SS{}; @@ -934,6 +1444,11 @@ ss_op_selector() else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::MMA_64x32x16_F32BF16BF16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x16_F32BF16BF16_SS{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::MMA_64x16x16_F32BF16BF16_SS{}; } @@ -954,66 +1469,141 @@ ss_op_selector() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::MMA_64x256x8_F32TF32TF32_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x8_F32TF32TF32_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::MMA_64x240x8_F32TF32TF32_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x8_F32TF32TF32_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::MMA_64x224x8_F32TF32TF32_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x8_F32TF32TF32_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::MMA_64x208x8_F32TF32TF32_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x8_F32TF32TF32_SS_TN{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::MMA_64x192x8_F32TF32TF32_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x8_F32TF32TF32_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if 
constexpr (Tile_N % 176 == 0) { return SM90::GMMA::MMA_64x176x8_F32TF32TF32_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x8_F32TF32TF32_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::MMA_64x160x8_F32TF32TF32_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x8_F32TF32TF32_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::MMA_64x144x8_F32TF32TF32_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x8_F32TF32TF32_SS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::MMA_64x128x8_F32TF32TF32_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x8_F32TF32TF32_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::MMA_64x112x8_F32TF32TF32_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x8_F32TF32TF32_SS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::MMA_64x96x8_F32TF32TF32_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x8_F32TF32TF32_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::MMA_64x80x8_F32TF32TF32_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return 
SM90::GMMA::MMA_64x72x8_F32TF32TF32_SS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::MMA_64x64x8_F32TF32TF32_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x8_F32TF32TF32_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::MMA_64x48x8_F32TF32TF32_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x8_F32TF32TF32_SS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::MMA_64x32x8_F32TF32TF32_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x8_F32TF32TF32_SS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::MMA_64x16x8_F32TF32TF32_SS_TN{}; } @@ -1034,66 +1624,141 @@ ss_op_selector() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::MMA_64x256x32_F32E4M3E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F32E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::MMA_64x240x32_F32E4M3E4M3_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F32E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::MMA_64x224x32_F32E4M3E4M3_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F32E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return 
SM90::GMMA::MMA_64x208x32_F32E4M3E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F32E4M3E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::MMA_64x192x32_F32E4M3E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F32E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::MMA_64x176x32_F32E4M3E4M3_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F32E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::MMA_64x160x32_F32E4M3E4M3_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F32E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::MMA_64x144x32_F32E4M3E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F32E4M3E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::MMA_64x128x32_F32E4M3E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F32E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::MMA_64x112x32_F32E4M3E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F32E4M3E4M3_SS_TN{}; + } #endif 
else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::MMA_64x96x32_F32E4M3E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F32E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::MMA_64x80x32_F32E4M3E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F32E4M3E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::MMA_64x64x32_F32E4M3E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F32E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::MMA_64x48x32_F32E4M3E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F32E4M3E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::MMA_64x32x32_F32E4M3E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F32E4M3E4M3_SS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::MMA_64x16x32_F32E4M3E4M3_SS_TN{}; } @@ -1114,66 +1779,141 @@ ss_op_selector() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::MMA_64x256x32_F32E4M3E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F32E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::MMA_64x240x32_F32E4M3E5M2_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if 
constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F32E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::MMA_64x224x32_F32E4M3E5M2_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F32E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::MMA_64x208x32_F32E4M3E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F32E4M3E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::MMA_64x192x32_F32E4M3E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F32E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::MMA_64x176x32_F32E4M3E5M2_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F32E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::MMA_64x160x32_F32E4M3E5M2_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F32E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::MMA_64x144x32_F32E4M3E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F32E4M3E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return 
SM90::GMMA::MMA_64x128x32_F32E4M3E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F32E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::MMA_64x112x32_F32E4M3E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F32E4M3E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::MMA_64x96x32_F32E4M3E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F32E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::MMA_64x80x32_F32E4M3E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F32E4M3E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::MMA_64x64x32_F32E4M3E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F32E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::MMA_64x48x32_F32E4M3E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F32E4M3E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::MMA_64x32x32_F32E4M3E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F32E4M3E5M2_SS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return 
SM90::GMMA::MMA_64x16x32_F32E4M3E5M2_SS_TN{}; } @@ -1194,66 +1934,141 @@ ss_op_selector() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::MMA_64x256x32_F32E5M2E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F32E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::MMA_64x240x32_F32E5M2E4M3_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F32E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::MMA_64x224x32_F32E5M2E4M3_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F32E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::MMA_64x208x32_F32E5M2E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F32E5M2E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::MMA_64x192x32_F32E5M2E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F32E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::MMA_64x176x32_F32E5M2E4M3_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F32E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return 
SM90::GMMA::MMA_64x160x32_F32E5M2E4M3_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F32E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::MMA_64x144x32_F32E5M2E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F32E5M2E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::MMA_64x128x32_F32E5M2E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F32E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::MMA_64x112x32_F32E5M2E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F32E5M2E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::MMA_64x96x32_F32E5M2E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F32E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::MMA_64x80x32_F32E5M2E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F32E5M2E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::MMA_64x64x32_F32E5M2E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F32E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr 
(Tile_N % 48 == 0) { return SM90::GMMA::MMA_64x48x32_F32E5M2E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F32E5M2E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::MMA_64x32x32_F32E5M2E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F32E5M2E4M3_SS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::MMA_64x16x32_F32E5M2E4M3_SS_TN{}; } @@ -1274,66 +2089,141 @@ ss_op_selector() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::MMA_64x256x32_F32E5M2E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F32E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::MMA_64x240x32_F32E5M2E5M2_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F32E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::MMA_64x224x32_F32E5M2E5M2_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F32E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::MMA_64x208x32_F32E5M2E5M2_SS_TN{}; } #endif - else if constexpr (Tile_N % 192 == 0) { +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F32E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::MMA_64x192x32_F32E5M2E5M2_SS_TN{}; } +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F32E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::MMA_64x176x32_F32E5M2E5M2_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F32E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::MMA_64x160x32_F32E5M2E5M2_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F32E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::MMA_64x144x32_F32E5M2E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F32E5M2E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::MMA_64x128x32_F32E5M2E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F32E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::MMA_64x112x32_F32E5M2E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F32E5M2E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::MMA_64x96x32_F32E5M2E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F32E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if 
constexpr (Tile_N % 80 == 0) { return SM90::GMMA::MMA_64x80x32_F32E5M2E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F32E5M2E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::MMA_64x64x32_F32E5M2E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F32E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::MMA_64x48x32_F32E5M2E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F32E5M2E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::MMA_64x32x32_F32E5M2E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F32E5M2E5M2_SS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::MMA_64x16x32_F32E5M2E5M2_SS_TN{}; } @@ -1422,6 +2312,11 @@ ss_op_selector() else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::MMA_64x32x32_S32S8S8_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_S32S8S8_SS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::MMA_64x16x32_S32S8S8_SS_TN{}; } @@ -1502,6 +2397,11 @@ ss_op_selector() else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::MMA_64x32x32_S32S8U8_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_S32S8U8_SS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::MMA_64x16x32_S32S8U8_SS_TN{}; } @@ -1582,6 +2482,11 @@ ss_op_selector() else if constexpr (Tile_N % 32 == 0) { 
return SM90::GMMA::MMA_64x32x32_S32U8S8_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_S32U8S8_SS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::MMA_64x16x32_S32U8S8_SS_TN{}; } @@ -1662,6 +2567,11 @@ ss_op_selector() else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::MMA_64x32x32_S32U8U8_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_S32U8U8_SS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::MMA_64x16x32_S32U8U8_SS_TN{}; } @@ -1713,66 +2623,141 @@ ss_op_selector_sparse() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::SPARSE::GMMA_64x256x32_F16F16F16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x32_F16F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::SPARSE::GMMA_64x240x32_F16F16F16_SS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x32_F16F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::SPARSE::GMMA_64x224x32_F16F16F16_SS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x32_F16F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::SPARSE::GMMA_64x208x32_F16F16F16_SS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x32_F16F16F16_SS{}; + } #endif else if constexpr (Tile_N % 192 == 0) { 
return SM90::GMMA::SPARSE::GMMA_64x192x32_F16F16F16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x32_F16F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::SPARSE::GMMA_64x176x32_F16F16F16_SS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x32_F16F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::SPARSE::GMMA_64x160x32_F16F16F16_SS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x32_F16F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::SPARSE::GMMA_64x144x32_F16F16F16_SS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x32_F16F16F16_SS{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::SPARSE::GMMA_64x128x32_F16F16F16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x32_F16F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::SPARSE::GMMA_64x112x32_F16F16F16_SS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x32_F16F16F16_SS{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::SPARSE::GMMA_64x96x32_F16F16F16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x88x32_F16F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::SPARSE::GMMA_64x80x32_F16F16F16_SS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x32_F16F16F16_SS{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::SPARSE::GMMA_64x64x32_F16F16F16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x32_F16F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::SPARSE::GMMA_64x48x32_F16F16F16_SS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x32_F16F16F16_SS{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::SPARSE::GMMA_64x32x32_F16F16F16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x32_F16F16F16_SS{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::SPARSE::GMMA_64x16x32_F16F16F16_SS{}; } @@ -1793,66 +2778,141 @@ ss_op_selector_sparse() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F16E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E4M3_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F16E4M3E4M3_SS_TN{}; + } +#endif #if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E4M3_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F16E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F16E4M3E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F16E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E4M3_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F16E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E4M3_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F16E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F16E4M3E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return 
SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F16E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F16E4M3E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F16E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F16E4M3E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F16E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F16E4M3E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E4M3_SS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E4M3_SS_TN{}; } @@ -1873,66 +2933,141 @@ ss_op_selector_sparse() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F16E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E5M2_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F16E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E5M2_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F16E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F16E4M3E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F16E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E5M2_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if 
constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F16E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E5M2_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F16E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F16E4M3E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F16E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F16E4M3E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F16E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F16E4M3E5M2_SS_TN{}; + } #endif else if 
constexpr (Tile_N % 64 == 0) { return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F16E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F16E4M3E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E5M2_SS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E5M2_SS_TN{}; } @@ -1953,66 +3088,141 @@ ss_op_selector_sparse() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F16E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E4M3_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F16E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E4M3_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F16E5M2E4M3_SS_TN{}; + } +#endif #if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F16E5M2E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F16E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E4M3_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F16E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E4M3_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F16E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F16E5M2E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F16E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return 
SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F16E5M2E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F16E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F16E5M2E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F16E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F16E5M2E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F16E5M2E4M3_SS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E4M3_SS_TN{}; } @@ -2033,66 +3243,141 @@ ss_op_selector_sparse() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E5M2_SS_TN{}; } +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F16E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E5M2_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F16E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E5M2_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F16E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F16E5M2E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F16E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E5M2_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F16E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E5M2_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else 
if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F16E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E5M2_SS_TN{}; } #endif - else if constexpr (Tile_N % 128 == 0) { +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F16E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F16E5M2E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F16E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F16E5M2E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F16E5M2E5M2_SS_TN{}; + } +#endif #if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F16E5M2E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F16E5M2E5M2_SS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E5M2_SS_TN{}; } @@ -2119,66 +3404,141 @@ ss_op_selector_sparse() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::SPARSE::GMMA_64x256x32_F32F16F16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x32_F32F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::SPARSE::GMMA_64x240x32_F32F16F16_SS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x32_F32F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::SPARSE::GMMA_64x224x32_F32F16F16_SS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x32_F32F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::SPARSE::GMMA_64x208x32_F32F16F16_SS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x32_F32F16F16_SS{}; + } #endif else if 
constexpr (Tile_N % 192 == 0) { return SM90::GMMA::SPARSE::GMMA_64x192x32_F32F16F16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x32_F32F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::SPARSE::GMMA_64x176x32_F32F16F16_SS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x32_F32F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::SPARSE::GMMA_64x160x32_F32F16F16_SS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x32_F32F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::SPARSE::GMMA_64x144x32_F32F16F16_SS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x32_F32F16F16_SS{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::SPARSE::GMMA_64x128x32_F32F16F16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x32_F32F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::SPARSE::GMMA_64x112x32_F32F16F16_SS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x32_F32F16F16_SS{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::SPARSE::GMMA_64x96x32_F32F16F16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 
88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x32_F32F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::SPARSE::GMMA_64x80x32_F32F16F16_SS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x32_F32F16F16_SS{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::SPARSE::GMMA_64x64x32_F32F16F16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x32_F32F16F16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::SPARSE::GMMA_64x48x32_F32F16F16_SS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x32_F32F16F16_SS{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::SPARSE::GMMA_64x32x32_F32F16F16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x32_F32F16F16_SS{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::SPARSE::GMMA_64x16x32_F32F16F16_SS{}; } @@ -2197,66 +3557,141 @@ ss_op_selector_sparse() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::SPARSE::GMMA_64x256x32_F32BF16BF16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x32_F32BF16BF16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::SPARSE::GMMA_64x240x32_F32BF16BF16_SS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x32_F32BF16BF16_SS{}; + } +#endif #if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::SPARSE::GMMA_64x224x32_F32BF16BF16_SS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x32_F32BF16BF16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::SPARSE::GMMA_64x208x32_F32BF16BF16_SS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x32_F32BF16BF16_SS{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::SPARSE::GMMA_64x192x32_F32BF16BF16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x32_F32BF16BF16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::SPARSE::GMMA_64x176x32_F32BF16BF16_SS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x32_F32BF16BF16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::SPARSE::GMMA_64x160x32_F32BF16BF16_SS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x32_F32BF16BF16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::SPARSE::GMMA_64x144x32_F32BF16BF16_SS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x32_F32BF16BF16_SS{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return 
SM90::GMMA::SPARSE::GMMA_64x128x32_F32BF16BF16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x32_F32BF16BF16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::SPARSE::GMMA_64x112x32_F32BF16BF16_SS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x32_F32BF16BF16_SS{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::SPARSE::GMMA_64x96x32_F32BF16BF16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x32_F32BF16BF16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::SPARSE::GMMA_64x80x32_F32BF16BF16_SS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x32_F32BF16BF16_SS{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::SPARSE::GMMA_64x64x32_F32BF16BF16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x32_F32BF16BF16_SS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::SPARSE::GMMA_64x48x32_F32BF16BF16_SS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x32_F32BF16BF16_SS{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::SPARSE::GMMA_64x32x32_F32BF16BF16_SS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x32_F32BF16BF16_SS{}; + } +#endif 
else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::SPARSE::GMMA_64x16x32_F32BF16BF16_SS{}; } @@ -2277,66 +3712,141 @@ ss_op_selector_sparse() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::SPARSE::GMMA_64x256x16_F32TF32TF32_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x16_F32TF32TF32_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::SPARSE::GMMA_64x240x16_F32TF32TF32_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x16_F32TF32TF32_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::SPARSE::GMMA_64x224x16_F32TF32TF32_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x16_F32TF32TF32_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::SPARSE::GMMA_64x208x16_F32TF32TF32_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x16_F32TF32TF32_SS_TN{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::SPARSE::GMMA_64x192x16_F32TF32TF32_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x16_F32TF32TF32_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::SPARSE::GMMA_64x176x16_F32TF32TF32_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x168x16_F32TF32TF32_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::SPARSE::GMMA_64x160x16_F32TF32TF32_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x16_F32TF32TF32_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::SPARSE::GMMA_64x144x16_F32TF32TF32_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x16_F32TF32TF32_SS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::SPARSE::GMMA_64x128x16_F32TF32TF32_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x16_F32TF32TF32_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::SPARSE::GMMA_64x112x16_F32TF32TF32_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x16_F32TF32TF32_SS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::SPARSE::GMMA_64x96x16_F32TF32TF32_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x16_F32TF32TF32_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::SPARSE::GMMA_64x80x16_F32TF32TF32_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x16_F32TF32TF32_SS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return 
SM90::GMMA::SPARSE::GMMA_64x64x16_F32TF32TF32_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x16_F32TF32TF32_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::SPARSE::GMMA_64x48x16_F32TF32TF32_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x16_F32TF32TF32_SS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::SPARSE::GMMA_64x32x16_F32TF32TF32_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x16_F32TF32TF32_SS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::SPARSE::GMMA_64x16x16_F32TF32TF32_SS_TN{}; } @@ -2357,66 +3867,141 @@ ss_op_selector_sparse() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F32E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E4M3_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F32E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E4M3_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F32E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 
208 == 0) { return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F32E4M3E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F32E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E4M3_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F32E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E4M3_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F32E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F32E4M3E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F32E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E4M3_SS_TN{}; } +#endif +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F32E4M3E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F32E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F32E4M3E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F32E4M3E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F32E4M3E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F32E4M3E4M3_SS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E4M3_SS_TN{}; } @@ -2437,66 +4022,141 @@ ss_op_selector_sparse() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x248x64_F32E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E5M2_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F32E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E5M2_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F32E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F32E4M3E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F32E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E5M2_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F32E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E5M2_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x152x64_F32E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F32E4M3E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F32E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F32E4M3E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F32E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F32E4M3E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F32E4M3E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return 
SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F32E4M3E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F32E4M3E5M2_SS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E5M2_SS_TN{}; } @@ -2517,66 +4177,141 @@ ss_op_selector_sparse() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F32E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E4M3_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F32E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E4M3_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F32E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F32E5M2E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return 
SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F32E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E4M3_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F32E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E4M3_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F32E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F32E5M2E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F32E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F32E5M2E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + 
else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F32E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F32E5M2E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F32E5M2E4M3_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E4M3_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F32E5M2E4M3_SS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E4M3_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F32E5M2E4M3_SS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E4M3_SS_TN{}; } @@ -2597,66 +4332,141 @@ ss_op_selector_sparse() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F32E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E5M2_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { 
+ return SM90::GMMA::SPARSE::GMMA_64x232x64_F32E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E5M2_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F32E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F32E5M2E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F32E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E5M2_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F32E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E5M2_SS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F32E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x136x64_F32E5M2E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F32E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F32E5M2E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F32E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F32E5M2E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E5M2_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F32E5M2E5M2_SS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E5M2_SS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F32E5M2E5M2_SS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E5M2_SS_TN{}; } +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F32E5M2E5M2_SS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E5M2_SS_TN{}; } @@ -2745,6 +4555,11 @@ ss_op_selector_sparse() else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8S8_SS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_SS_TN{}; } @@ -2825,6 +4640,11 @@ ss_op_selector_sparse() else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8U8_SS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_SS_TN{}; } @@ -2905,6 +4725,11 @@ ss_op_selector_sparse() else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8S8_SS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_SS_TN{}; } @@ -2985,6 +4810,11 @@ ss_op_selector_sparse() else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_SS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8U8_SS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_SS_TN{}; } @@ -3037,66 +4867,141 @@ rs_op_selector() if constexpr (Tile_N % 256 == 0) { return 
SM90::GMMA::MMA_64x256x16_F16F16F16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x16_F16F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::MMA_64x240x16_F16F16F16_RS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x16_F16F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::MMA_64x224x16_F16F16F16_RS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x16_F16F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::MMA_64x208x16_F16F16F16_RS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x16_F16F16F16_RS{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::MMA_64x192x16_F16F16F16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x16_F16F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::MMA_64x176x16_F16F16F16_RS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x16_F16F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::MMA_64x160x16_F16F16F16_RS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x16_F16F16F16_RS{}; + } +#endif #if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::MMA_64x144x16_F16F16F16_RS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x16_F16F16F16_RS{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::MMA_64x128x16_F16F16F16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x16_F16F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::MMA_64x112x16_F16F16F16_RS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x16_F16F16F16_RS{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::MMA_64x96x16_F16F16F16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x16_F16F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::MMA_64x80x16_F16F16F16_RS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x16_F16F16F16_RS{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::MMA_64x64x16_F16F16F16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x16_F16F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::MMA_64x48x16_F16F16F16_RS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x16_F16F16F16_RS{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return 
SM90::GMMA::MMA_64x32x16_F16F16F16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x16_F16F16F16_RS{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::MMA_64x16x16_F16F16F16_RS{}; } @@ -3117,66 +5022,141 @@ rs_op_selector() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::MMA_64x256x32_F16E4M3E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F16E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::MMA_64x240x32_F16E4M3E4M3_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F16E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::MMA_64x224x32_F16E4M3E4M3_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F16E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::MMA_64x208x32_F16E4M3E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F16E4M3E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::MMA_64x192x32_F16E4M3E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F16E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::MMA_64x176x32_F16E4M3E4M3_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + 
else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F16E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::MMA_64x160x32_F16E4M3E4M3_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F16E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::MMA_64x144x32_F16E4M3E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F16E4M3E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::MMA_64x128x32_F16E4M3E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F16E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::MMA_64x112x32_F16E4M3E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F16E4M3E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::MMA_64x96x32_F16E4M3E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F16E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::MMA_64x80x32_F16E4M3E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F16E4M3E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::MMA_64x64x32_F16E4M3E4M3_RS_TN{}; } +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F16E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::MMA_64x48x32_F16E4M3E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F16E4M3E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::MMA_64x32x32_F16E4M3E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F16E4M3E4M3_RS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::MMA_64x16x32_F16E4M3E4M3_RS_TN{}; } @@ -3197,66 +5177,141 @@ rs_op_selector() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::MMA_64x256x32_F16E4M3E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F16E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::MMA_64x240x32_F16E4M3E5M2_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F16E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::MMA_64x224x32_F16E4M3E5M2_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F16E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::MMA_64x208x32_F16E4M3E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return 
SM90::GMMA::MMA_64x200x32_F16E4M3E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::MMA_64x192x32_F16E4M3E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F16E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::MMA_64x176x32_F16E4M3E5M2_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F16E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::MMA_64x160x32_F16E4M3E5M2_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F16E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::MMA_64x144x32_F16E4M3E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F16E4M3E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::MMA_64x128x32_F16E4M3E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F16E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::MMA_64x112x32_F16E4M3E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F16E4M3E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::MMA_64x96x32_F16E4M3E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if 
constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F16E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::MMA_64x80x32_F16E4M3E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F16E4M3E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::MMA_64x64x32_F16E4M3E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F16E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::MMA_64x48x32_F16E4M3E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F16E4M3E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::MMA_64x32x32_F16E4M3E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F16E4M3E5M2_RS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::MMA_64x16x32_F16E4M3E5M2_RS_TN{}; } @@ -3277,66 +5332,141 @@ rs_op_selector() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::MMA_64x256x32_F16E5M2E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F16E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::MMA_64x240x32_F16E5M2E4M3_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F16E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if 
constexpr (Tile_N % 224 == 0) { return SM90::GMMA::MMA_64x224x32_F16E5M2E4M3_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F16E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::MMA_64x208x32_F16E5M2E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F16E5M2E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::MMA_64x192x32_F16E5M2E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F16E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::MMA_64x176x32_F16E5M2E4M3_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F16E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::MMA_64x160x32_F16E5M2E4M3_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F16E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::MMA_64x144x32_F16E5M2E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F16E5M2E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::MMA_64x128x32_F16E5M2E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return 
SM90::GMMA::MMA_64x120x32_F16E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::MMA_64x112x32_F16E5M2E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F16E5M2E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::MMA_64x96x32_F16E5M2E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F16E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::MMA_64x80x32_F16E5M2E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F16E5M2E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::MMA_64x64x32_F16E5M2E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F16E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::MMA_64x48x32_F16E5M2E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F16E5M2E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::MMA_64x32x32_F16E5M2E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F16E5M2E4M3_RS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::MMA_64x16x32_F16E5M2E4M3_RS_TN{}; } @@ -3357,66 +5487,141 @@ rs_op_selector() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::MMA_64x256x32_F16E5M2E5M2_RS_TN{}; } +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F16E5M2E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::MMA_64x240x32_F16E5M2E5M2_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F16E5M2E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::MMA_64x224x32_F16E5M2E5M2_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F16E5M2E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::MMA_64x208x32_F16E5M2E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F16E5M2E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::MMA_64x192x32_F16E5M2E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F16E5M2E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::MMA_64x176x32_F16E5M2E5M2_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F16E5M2E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::MMA_64x160x32_F16E5M2E5M2_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F16E5M2E5M2_RS_TN{}; + } +#endif #if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::MMA_64x144x32_F16E5M2E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F16E5M2E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::MMA_64x128x32_F16E5M2E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F16E5M2E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::MMA_64x112x32_F16E5M2E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F16E5M2E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::MMA_64x96x32_F16E5M2E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F16E5M2E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::MMA_64x80x32_F16E5M2E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F16E5M2E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::MMA_64x64x32_F16E5M2E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F16E5M2E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::MMA_64x48x32_F16E5M2E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return 
SM90::GMMA::MMA_64x40x32_F16E5M2E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::MMA_64x32x32_F16E5M2E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F16E5M2E5M2_RS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::MMA_64x16x32_F16E5M2E5M2_RS_TN{}; } @@ -3443,66 +5648,141 @@ rs_op_selector() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::MMA_64x256x16_F32F16F16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x16_F32F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::MMA_64x240x16_F32F16F16_RS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x16_F32F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::MMA_64x224x16_F32F16F16_RS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x16_F32F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::MMA_64x208x16_F32F16F16_RS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x16_F32F16F16_RS{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::MMA_64x192x16_F32F16F16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x16_F32F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::MMA_64x176x16_F32F16F16_RS{}; } 
#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x16_F32F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::MMA_64x160x16_F32F16F16_RS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x16_F32F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::MMA_64x144x16_F32F16F16_RS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x16_F32F16F16_RS{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::MMA_64x128x16_F32F16F16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x16_F32F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::MMA_64x112x16_F32F16F16_RS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x16_F32F16F16_RS{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::MMA_64x96x16_F32F16F16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x16_F32F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::MMA_64x80x16_F32F16F16_RS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x16_F32F16F16_RS{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::MMA_64x64x16_F32F16F16_RS{}; } +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x16_F32F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::MMA_64x48x16_F32F16F16_RS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x16_F32F16F16_RS{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::MMA_64x32x16_F32F16F16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x16_F32F16F16_RS{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::MMA_64x16x16_F32F16F16_RS{}; } @@ -3521,66 +5801,141 @@ rs_op_selector() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::MMA_64x256x16_F32BF16BF16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x16_F32BF16BF16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::MMA_64x240x16_F32BF16BF16_RS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x16_F32BF16BF16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::MMA_64x224x16_F32BF16BF16_RS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x16_F32BF16BF16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::MMA_64x208x16_F32BF16BF16_RS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x16_F32BF16BF16_RS{}; + } 
#endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::MMA_64x192x16_F32BF16BF16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x16_F32BF16BF16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::MMA_64x176x16_F32BF16BF16_RS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x16_F32BF16BF16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::MMA_64x160x16_F32BF16BF16_RS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x16_F32BF16BF16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::MMA_64x144x16_F32BF16BF16_RS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x16_F32BF16BF16_RS{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::MMA_64x128x16_F32BF16BF16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x16_F32BF16BF16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::MMA_64x112x16_F32BF16BF16_RS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x16_F32BF16BF16_RS{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::MMA_64x96x16_F32BF16BF16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x16_F32BF16BF16_RS{}; + } +#endif 
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::MMA_64x80x16_F32BF16BF16_RS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x16_F32BF16BF16_RS{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::MMA_64x64x16_F32BF16BF16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x16_F32BF16BF16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::MMA_64x48x16_F32BF16BF16_RS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x16_F32BF16BF16_RS{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::MMA_64x32x16_F32BF16BF16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x16_F32BF16BF16_RS{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::MMA_64x16x16_F32BF16BF16_RS{}; } @@ -3601,66 +5956,141 @@ rs_op_selector() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::MMA_64x256x8_F32TF32TF32_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x8_F32TF32TF32_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::MMA_64x240x8_F32TF32TF32_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x8_F32TF32TF32_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::MMA_64x224x8_F32TF32TF32_RS_TN{}; } #endif +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x8_F32TF32TF32_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::MMA_64x208x8_F32TF32TF32_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x8_F32TF32TF32_RS_TN{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::MMA_64x192x8_F32TF32TF32_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x8_F32TF32TF32_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::MMA_64x176x8_F32TF32TF32_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x8_F32TF32TF32_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::MMA_64x160x8_F32TF32TF32_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x8_F32TF32TF32_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::MMA_64x144x8_F32TF32TF32_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x8_F32TF32TF32_RS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::MMA_64x128x8_F32TF32TF32_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x8_F32TF32TF32_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr 
(Tile_N % 112 == 0) { return SM90::GMMA::MMA_64x112x8_F32TF32TF32_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x8_F32TF32TF32_RS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::MMA_64x96x8_F32TF32TF32_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x8_F32TF32TF32_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::MMA_64x80x8_F32TF32TF32_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x8_F32TF32TF32_RS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::MMA_64x64x8_F32TF32TF32_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x8_F32TF32TF32_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::MMA_64x48x8_F32TF32TF32_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x8_F32TF32TF32_RS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::MMA_64x32x8_F32TF32TF32_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x8_F32TF32TF32_RS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::MMA_64x16x8_F32TF32TF32_RS_TN{}; } @@ -3681,66 +6111,141 @@ rs_op_selector() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::MMA_64x256x32_F32E4M3E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F32E4M3E4M3_RS_TN{}; + 
} +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::MMA_64x240x32_F32E4M3E4M3_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F32E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::MMA_64x224x32_F32E4M3E4M3_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F32E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::MMA_64x208x32_F32E4M3E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F32E4M3E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::MMA_64x192x32_F32E4M3E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F32E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::MMA_64x176x32_F32E4M3E4M3_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F32E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::MMA_64x160x32_F32E4M3E4M3_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F32E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::MMA_64x144x32_F32E4M3E4M3_RS_TN{}; } 
+#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F32E4M3E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::MMA_64x128x32_F32E4M3E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F32E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::MMA_64x112x32_F32E4M3E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F32E4M3E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::MMA_64x96x32_F32E4M3E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F32E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::MMA_64x80x32_F32E4M3E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F32E4M3E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::MMA_64x64x32_F32E4M3E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F32E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::MMA_64x48x32_F32E4M3E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F32E4M3E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::MMA_64x32x32_F32E4M3E4M3_RS_TN{}; } +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F32E4M3E4M3_RS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::MMA_64x16x32_F32E4M3E4M3_RS_TN{}; } @@ -3761,66 +6266,141 @@ rs_op_selector() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::MMA_64x256x32_F32E4M3E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F32E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::MMA_64x240x32_F32E4M3E5M2_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F32E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::MMA_64x224x32_F32E4M3E5M2_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F32E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::MMA_64x208x32_F32E4M3E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F32E4M3E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::MMA_64x192x32_F32E4M3E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F32E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::MMA_64x176x32_F32E4M3E5M2_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + 
return SM90::GMMA::MMA_64x168x32_F32E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::MMA_64x160x32_F32E4M3E5M2_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F32E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::MMA_64x144x32_F32E4M3E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F32E4M3E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::MMA_64x128x32_F32E4M3E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F32E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::MMA_64x112x32_F32E4M3E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F32E4M3E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::MMA_64x96x32_F32E4M3E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F32E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::MMA_64x80x32_F32E4M3E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F32E4M3E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::MMA_64x64x32_F32E4M3E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if 
constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F32E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::MMA_64x48x32_F32E4M3E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F32E4M3E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::MMA_64x32x32_F32E4M3E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F32E4M3E5M2_RS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::MMA_64x16x32_F32E4M3E5M2_RS_TN{}; } @@ -3841,66 +6421,141 @@ rs_op_selector() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::MMA_64x256x32_F32E5M2E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F32E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::MMA_64x240x32_F32E5M2E4M3_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F32E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::MMA_64x224x32_F32E5M2E4M3_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F32E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::MMA_64x208x32_F32E5M2E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F32E5M2E4M3_RS_TN{}; + } 
#endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::MMA_64x192x32_F32E5M2E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F32E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::MMA_64x176x32_F32E5M2E4M3_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F32E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::MMA_64x160x32_F32E5M2E4M3_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F32E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::MMA_64x144x32_F32E5M2E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F32E5M2E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::MMA_64x128x32_F32E5M2E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F32E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::MMA_64x112x32_F32E5M2E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F32E5M2E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::MMA_64x96x32_F32E5M2E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return 
SM90::GMMA::MMA_64x88x32_F32E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::MMA_64x80x32_F32E5M2E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F32E5M2E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::MMA_64x64x32_F32E5M2E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F32E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::MMA_64x48x32_F32E5M2E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F32E5M2E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::MMA_64x32x32_F32E5M2E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F32E5M2E4M3_RS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::MMA_64x16x32_F32E5M2E4M3_RS_TN{}; } @@ -3921,66 +6576,141 @@ rs_op_selector() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::MMA_64x256x32_F32E5M2E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F32E5M2E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::MMA_64x240x32_F32E5M2E5M2_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F32E5M2E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return 
SM90::GMMA::MMA_64x224x32_F32E5M2E5M2_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F32E5M2E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::MMA_64x208x32_F32E5M2E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F32E5M2E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::MMA_64x192x32_F32E5M2E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F32E5M2E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::MMA_64x176x32_F32E5M2E5M2_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F32E5M2E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::MMA_64x160x32_F32E5M2E5M2_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F32E5M2E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::MMA_64x144x32_F32E5M2E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F32E5M2E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::MMA_64x128x32_F32E5M2E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F32E5M2E5M2_RS_TN{}; + } +#endif 
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::MMA_64x112x32_F32E5M2E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F32E5M2E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::MMA_64x96x32_F32E5M2E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F32E5M2E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::MMA_64x80x32_F32E5M2E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F32E5M2E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::MMA_64x64x32_F32E5M2E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F32E5M2E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::MMA_64x48x32_F32E5M2E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F32E5M2E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::MMA_64x32x32_F32E5M2E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F32E5M2E5M2_RS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::MMA_64x16x32_F32E5M2E5M2_RS_TN{}; } @@ -4069,6 +6799,11 @@ rs_op_selector() else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::MMA_64x32x32_S32S8S8_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr 
(Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_S32S8S8_RS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::MMA_64x16x32_S32S8S8_RS_TN{}; } @@ -4149,6 +6884,11 @@ rs_op_selector() else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::MMA_64x32x32_S32S8U8_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_S32S8U8_RS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::MMA_64x16x32_S32S8U8_RS_TN{}; } @@ -4229,6 +6969,11 @@ rs_op_selector() else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::MMA_64x32x32_S32U8S8_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_S32U8S8_RS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::MMA_64x16x32_S32U8S8_RS_TN{}; } @@ -4309,6 +7054,11 @@ rs_op_selector() else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::MMA_64x32x32_S32U8U8_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_S32U8U8_RS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::MMA_64x16x32_S32U8U8_RS_TN{}; } @@ -4361,66 +7111,141 @@ rs_op_selector_sparse() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::SPARSE::GMMA_64x256x32_F16F16F16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x32_F16F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::SPARSE::GMMA_64x240x32_F16F16F16_RS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x32_F16F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if 
constexpr (Tile_N % 224 == 0) { return SM90::GMMA::SPARSE::GMMA_64x224x32_F16F16F16_RS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x32_F16F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::SPARSE::GMMA_64x208x32_F16F16F16_RS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x32_F16F16F16_RS{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::SPARSE::GMMA_64x192x32_F16F16F16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x32_F16F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::SPARSE::GMMA_64x176x32_F16F16F16_RS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x32_F16F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::SPARSE::GMMA_64x160x32_F16F16F16_RS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x32_F16F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::SPARSE::GMMA_64x144x32_F16F16F16_RS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x32_F16F16F16_RS{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::SPARSE::GMMA_64x128x32_F16F16F16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr 
(Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x32_F16F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::SPARSE::GMMA_64x112x32_F16F16F16_RS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x32_F16F16F16_RS{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::SPARSE::GMMA_64x96x32_F16F16F16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x32_F16F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::SPARSE::GMMA_64x80x32_F16F16F16_RS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x32_F16F16F16_RS{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::SPARSE::GMMA_64x64x32_F16F16F16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x32_F16F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::SPARSE::GMMA_64x48x32_F16F16F16_RS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x32_F16F16F16_RS{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::SPARSE::GMMA_64x32x32_F16F16F16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x32_F16F16F16_RS{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::SPARSE::GMMA_64x16x32_F16F16F16_RS{}; } @@ -4441,66 +7266,141 @@ rs_op_selector_sparse() if 
constexpr (Tile_N % 256 == 0) { return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F16E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E4M3_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F16E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E4M3_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F16E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F16E4M3E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F16E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E4M3_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F16E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return 
SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E4M3_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F16E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F16E4M3E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F16E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F16E4M3E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F16E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F16E4M3E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x56x64_F16E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F16E4M3E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E4M3_RS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E4M3_RS_TN{}; } @@ -4521,66 +7421,141 @@ rs_op_selector_sparse() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F16E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E5M2_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F16E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E5M2_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F16E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr 
(Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F16E4M3E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F16E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E5M2_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F16E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E5M2_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F16E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F16E4M3E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F16E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F16E4M3E5M2_RS_TN{}; + } #endif else if 
constexpr (Tile_N % 96 == 0) { return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F16E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F16E4M3E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F16E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F16E4M3E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E5M2_RS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E5M2_RS_TN{}; } @@ -4601,66 +7576,141 @@ rs_op_selector_sparse() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F16E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return 
SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E4M3_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F16E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E4M3_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F16E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F16E5M2E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F16E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E4M3_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F16E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E4M3_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F16E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return 
SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F16E5M2E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F16E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F16E5M2E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F16E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F16E5M2E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F16E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x40x64_F16E5M2E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F16E5M2E4M3_RS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E4M3_RS_TN{}; } @@ -4681,66 +7731,141 @@ rs_op_selector_sparse() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F16E5M2E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E5M2_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F16E5M2E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E5M2_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F16E5M2E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F16E5M2E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x184x64_F16E5M2E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E5M2_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F16E5M2E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E5M2_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F16E5M2E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F16E5M2E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F16E5M2E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F16E5M2E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F16E5M2E5M2_RS_TN{}; + } +#endif #if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F16E5M2E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F16E5M2E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F16E5M2E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F16E5M2E5M2_RS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E5M2_RS_TN{}; } @@ -4767,66 +7892,141 @@ rs_op_selector_sparse() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::SPARSE::GMMA_64x256x32_F32F16F16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x32_F32F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::SPARSE::GMMA_64x240x32_F32F16F16_RS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x32_F32F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if 
constexpr (Tile_N % 224 == 0) { return SM90::GMMA::SPARSE::GMMA_64x224x32_F32F16F16_RS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x32_F32F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::SPARSE::GMMA_64x208x32_F32F16F16_RS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x32_F32F16F16_RS{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::SPARSE::GMMA_64x192x32_F32F16F16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x32_F32F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::SPARSE::GMMA_64x176x32_F32F16F16_RS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x32_F32F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::SPARSE::GMMA_64x160x32_F32F16F16_RS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x32_F32F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::SPARSE::GMMA_64x144x32_F32F16F16_RS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x32_F32F16F16_RS{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::SPARSE::GMMA_64x128x32_F32F16F16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr 
(Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x32_F32F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::SPARSE::GMMA_64x112x32_F32F16F16_RS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x32_F32F16F16_RS{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::SPARSE::GMMA_64x96x32_F32F16F16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x32_F32F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::SPARSE::GMMA_64x80x32_F32F16F16_RS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x32_F32F16F16_RS{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::SPARSE::GMMA_64x64x32_F32F16F16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x32_F32F16F16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::SPARSE::GMMA_64x48x32_F32F16F16_RS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x32_F32F16F16_RS{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::SPARSE::GMMA_64x32x32_F32F16F16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x32_F32F16F16_RS{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::SPARSE::GMMA_64x16x32_F32F16F16_RS{}; } @@ -4845,66 +8045,141 @@ rs_op_selector_sparse() if 
constexpr (Tile_N % 256 == 0) { return SM90::GMMA::SPARSE::GMMA_64x256x32_F32BF16BF16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x32_F32BF16BF16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::SPARSE::GMMA_64x240x32_F32BF16BF16_RS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x32_F32BF16BF16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::SPARSE::GMMA_64x224x32_F32BF16BF16_RS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x32_F32BF16BF16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::SPARSE::GMMA_64x208x32_F32BF16BF16_RS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x32_F32BF16BF16_RS{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::SPARSE::GMMA_64x192x32_F32BF16BF16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x32_F32BF16BF16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::SPARSE::GMMA_64x176x32_F32BF16BF16_RS{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x32_F32BF16BF16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::SPARSE::GMMA_64x160x32_F32BF16BF16_RS{}; } 
#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x32_F32BF16BF16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::SPARSE::GMMA_64x144x32_F32BF16BF16_RS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x32_F32BF16BF16_RS{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::SPARSE::GMMA_64x128x32_F32BF16BF16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x32_F32BF16BF16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::SPARSE::GMMA_64x112x32_F32BF16BF16_RS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x32_F32BF16BF16_RS{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::SPARSE::GMMA_64x96x32_F32BF16BF16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x32_F32BF16BF16_RS{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::SPARSE::GMMA_64x80x32_F32BF16BF16_RS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x32_F32BF16BF16_RS{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::SPARSE::GMMA_64x64x32_F32BF16BF16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x32_F32BF16BF16_RS{}; + } +#endif #if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::SPARSE::GMMA_64x48x32_F32BF16BF16_RS{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x32_F32BF16BF16_RS{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::SPARSE::GMMA_64x32x32_F32BF16BF16_RS{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x32_F32BF16BF16_RS{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::SPARSE::GMMA_64x16x32_F32BF16BF16_RS{}; } @@ -4925,66 +8200,141 @@ rs_op_selector_sparse() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::SPARSE::GMMA_64x256x16_F32TF32TF32_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x16_F32TF32TF32_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::SPARSE::GMMA_64x240x16_F32TF32TF32_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x16_F32TF32TF32_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::SPARSE::GMMA_64x224x16_F32TF32TF32_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x16_F32TF32TF32_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::SPARSE::GMMA_64x208x16_F32TF32TF32_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x16_F32TF32TF32_RS_TN{}; + 
} #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::SPARSE::GMMA_64x192x16_F32TF32TF32_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x16_F32TF32TF32_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::SPARSE::GMMA_64x176x16_F32TF32TF32_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x16_F32TF32TF32_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::SPARSE::GMMA_64x160x16_F32TF32TF32_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x16_F32TF32TF32_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::SPARSE::GMMA_64x144x16_F32TF32TF32_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x16_F32TF32TF32_RS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::SPARSE::GMMA_64x128x16_F32TF32TF32_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x16_F32TF32TF32_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::SPARSE::GMMA_64x112x16_F32TF32TF32_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x16_F32TF32TF32_RS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::SPARSE::GMMA_64x96x16_F32TF32TF32_RS_TN{}; } 
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x16_F32TF32TF32_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::SPARSE::GMMA_64x80x16_F32TF32TF32_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x16_F32TF32TF32_RS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::SPARSE::GMMA_64x64x16_F32TF32TF32_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x16_F32TF32TF32_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::SPARSE::GMMA_64x48x16_F32TF32TF32_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x16_F32TF32TF32_RS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::SPARSE::GMMA_64x32x16_F32TF32TF32_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x16_F32TF32TF32_RS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::SPARSE::GMMA_64x16x16_F32TF32TF32_RS_TN{}; } @@ -5005,66 +8355,141 @@ rs_op_selector_sparse() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F32E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E4M3_RS_TN{}; } #endif +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F32E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E4M3_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F32E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F32E4M3E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F32E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E4M3_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F32E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E4M3_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F32E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + 
else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F32E4M3E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F32E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F32E4M3E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F32E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F32E4M3E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F32E4M3E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F32E4M3E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return 
SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F32E4M3E4M3_RS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E4M3_RS_TN{}; } @@ -5085,66 +8510,141 @@ rs_op_selector_sparse() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F32E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E5M2_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F32E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E5M2_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F32E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F32E4M3E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F32E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr 
(Tile_N % 176 == 0) { return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E5M2_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F32E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E5M2_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F32E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F32E4M3E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F32E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F32E4M3E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F32E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E5M2_RS_TN{}; } +#endif +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F32E4M3E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F32E4M3E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F32E4M3E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F32E4M3E5M2_RS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E5M2_RS_TN{}; } @@ -5165,66 +8665,141 @@ rs_op_selector_sparse() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F32E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E4M3_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F32E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E4M3_RS_TN{}; } #endif +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F32E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F32E5M2E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F32E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E4M3_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F32E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 160 == 0) { return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E4M3_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F32E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F32E5M2E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x120x64_F32E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F32E5M2E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F32E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F32E5M2E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F32E5M2E4M3_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E4M3_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F32E5M2E4M3_RS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E4M3_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F32E5M2E4M3_RS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E4M3_RS_TN{}; } @@ -5245,66 
+8820,141 @@ rs_op_selector_sparse() if constexpr (Tile_N % 256 == 0) { return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F32E5M2E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 240 == 0) { return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E5M2_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F32E5M2E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 224 == 0) { return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E5M2_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F32E5M2E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 208 == 0) { return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F32E5M2E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 192 == 0) { return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F32E5M2E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 176 == 0) { return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E5M2_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F32E5M2E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 
160 == 0) { return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E5M2_RS_TN{}; } #endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F32E5M2E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 144 == 0) { return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F32E5M2E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 128 == 0) { return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F32E5M2E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 112 == 0) { return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F32E5M2E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 96 == 0) { return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F32E5M2E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 80 == 0) { return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F32E5M2E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 64 == 0) { return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + 
return SM90::GMMA::SPARSE::GMMA_64x56x64_F32E5M2E5M2_RS_TN{}; + } +#endif #if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) else if constexpr (Tile_N % 48 == 0) { return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E5M2_RS_TN{}; } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F32E5M2E5M2_RS_TN{}; + } #endif else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E5M2_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F32E5M2E5M2_RS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E5M2_RS_TN{}; } @@ -5393,6 +9043,11 @@ rs_op_selector_sparse() else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8S8_RS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_RS_TN{}; } @@ -5473,6 +9128,11 @@ rs_op_selector_sparse() else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8U8_RS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_RS_TN{}; } @@ -5553,6 +9213,11 @@ rs_op_selector_sparse() else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8S8_RS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return 
SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_RS_TN{}; } @@ -5633,6 +9298,11 @@ rs_op_selector_sparse() else if constexpr (Tile_N % 32 == 0) { return SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_RS_TN{}; } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8U8_RS_TN{}; + } +#endif else if constexpr (Tile_N % 16 == 0) { return SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_RS_TN{}; } diff --git a/include/cute/arch/mma_sm90_gmma.hpp b/include/cute/arch/mma_sm90_gmma.hpp index 1213823bc..d809aa4a6 100644 --- a/include/cute/arch/mma_sm90_gmma.hpp +++ b/include/cute/arch/mma_sm90_gmma.hpp @@ -406,111 +406,6 @@ struct MMA_64x32x16_F16F16F16_RS //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x16 F16+=F16*F16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x48x16_F16F16F16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[12]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %14, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k16.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11}," - " %12," - " %13," - " p, %15, %16, %17, %18;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), 
"+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x16 F16+=F16*F16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x48x16_F16F16F16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[12]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %17, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k16.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11}," - "{%12, %13, %14, %15}," - " %16," - " p, %18, %19, %20;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - // GMMA 64x64x16 F16+=F16*F16 template < GMMA::Major tnspA, @@ -616,121 +511,6 @@ struct MMA_64x64x16_F16F16F16_RS //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x16 F16+=F16*F16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x80x16_F16F16F16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[20]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %22, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k16.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19}," - " %20," - " %21," - " p, %23, %24, %25, %26;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), 
"+r"(d17), "+r"(d18), "+r"(d19) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x16 F16+=F16*F16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x80x16_F16F16F16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[20]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %25, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k16.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19}," - "{%20, %21, %22, %23}," - " %24," - " p, %26, %27, %28;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), 
"+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - // GMMA 64x96x16 F16+=F16*F16 template < GMMA::Major tnspA, @@ -846,131 +626,6 @@ struct MMA_64x96x16_F16F16F16_RS //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x16 F16+=F16*F16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x112x16_F16F16F16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[28]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %30, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k16.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, 
%6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27}," - " %28," - " %29," - " p, %31, %32, %33, %34;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x16 F16+=F16*F16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x112x16_F16F16F16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[28]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, 
uint32_t & d27, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %33, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k16.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27}," - "{%28, %29, %30, %31}," - " %32," - " p, %34, %35, %36;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - // GMMA 64x128x16 F16+=F16*F16 template < GMMA::Major tnspA, @@ -1096,20 +751,19 @@ struct MMA_64x128x16_F16F16F16_RS //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x16 F16+=F16*F16 +// GMMA 64x192x16 F16+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x144x16_F16F16F16_SS +struct MMA_64x192x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[36]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void fma(uint64_t 
const& desc_a, @@ -1123,6 +777,9 @@ struct MMA_64x144x16_F16F16F16_SS uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -1130,16 +787,17 @@ struct MMA_64x144x16_F16F16F16_SS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %38, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k16.f16.f16.f16 " + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35}," - " %36," - " %37," - " p, %39, %40, %41, %42;\n" + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52, %53, %54;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -1149,33 +807,34 @@ struct MMA_64x144x16_F16F16F16_SS "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x16 F16+=F16*F16 +// GMMA 64x192x16 F16+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x144x16_F16F16F16_RS +struct MMA_64x192x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[36]; + using CRegisters = uint32_t[48]; static_assert(tnspA == GMMA::Major::K, "Register source operand A must have K major layout."); @@ -1192,6 +851,9 @@ struct MMA_64x144x16_F16F16F16_RS uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -1199,16 +861,17 @@ struct MMA_64x144x16_F16F16F16_RS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %41, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k16.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35}," - "{%36, %37, %38, %39}," - " %40," - " p, %42, %43, %44;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, 
%23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55, %56;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -1218,33 +881,34 @@ struct MMA_64x144x16_F16F16F16_RS "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x16 F16+=F16*F16 +// GMMA 64x256x16 F16+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x160x16_F16F16F16_SS +struct MMA_64x256x16_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -1259,6 +923,12 @@ struct MMA_64x160x16_F16F16F16_SS uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & 
d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -1266,16 +936,19 @@ struct MMA_64x160x16_F16F16F16_SS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %42, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k16.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " p, %43, %44, %45, %46;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68, %69, %70;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -1286,33 +959,37 @@ struct MMA_64x160x16_F16F16F16_SS "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), 
"+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x16 F16+=F16*F16 +// GMMA 64x256x16 F16+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x160x16_F16F16F16_RS +struct MMA_64x256x16_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; + using CRegisters = uint32_t[64]; static_assert(tnspA == GMMA::Major::K, "Register source operand A must have K major layout."); @@ -1330,6 +1007,12 @@ struct MMA_64x160x16_F16F16F16_RS uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -1337,16 +1020,19 @@ struct MMA_64x160x16_F16F16F16_RS asm volatile( "{\n" ".reg 
.pred p;\n" - "setp.ne.b32 p, %45, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k16.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " p, %46, %47, %48;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71, %72;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -1357,48 +1043,42 @@ struct MMA_64x160x16_F16F16F16_RS "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif 
//////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x16 F16+=F16*F16 +// GMMA 64x8x16 F32+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x176x16_F16F16F16_SS +struct MMA_64x8x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[44]; + using CRegisters = float[4]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + float & d0, float & d1, float & d2, float & d3, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -1406,73 +1086,46 @@ struct MMA_64x176x16_F16F16F16_SS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %46, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k16.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43}," - " %44," - " %45," - " p, %47, %48, %49, 
%50;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8, %9, %10;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x16 F16+=F16*F16 +// GMMA 64x8x16 F32+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x176x16_F16F16F16_RS +struct MMA_64x8x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[44]; + using CRegisters = float[4]; static_assert(tnspA == GMMA::Major::K, "Register source operand A must have K major layout."); CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& 
desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + float & d0, float & d1, float & d2, float & d3, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -1480,70 +1133,44 @@ struct MMA_64x176x16_F16F16F16_RS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %49, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k16.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43}," - "{%44, %45, %46, %47}," - " %48," - " p, %50, %51, %52;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11, %12;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - 
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x192x16 F16+=F16*F16 +// GMMA 64x16x16 F32+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x192x16_F16F16F16_SS +struct MMA_64x16x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; + using CRegisters = float[8]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, 
float & d6, float & d7, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -1551,73 +1178,48 @@ struct MMA_64x192x16_F16F16F16_SS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %50, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " p, %51, %52, %53, %54;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12, %13, %14;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x192x16 F16+=F16*F16 +// GMMA 64x16x16 F32+=F16*F16 template < GMMA::Major 
tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x192x16_F16F16F16_RS +struct MMA_64x16x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; + using CRegisters = float[8]; static_assert(tnspA == GMMA::Major::K, "Register source operand A must have K major layout."); CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -1625,72 +1227,47 @@ struct MMA_64x192x16_F16F16F16_RS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %53, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, 
%31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " p, %54, %55, %56;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15, %16;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x16 F16+=F16*F16 +// GMMA 64x32x16 F32+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x208x16_F16F16F16_SS +struct MMA_64x32x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[52]; + using CRegisters = float[16]; 
CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -1698,58 +1275,42 @@ struct MMA_64x208x16_F16F16F16_SS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %54, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k16.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51}," - " %52," - " %53," - " p, %55, %56, %57, %58;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20, %21, %22;\n" "}\n" - 
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x16 F16+=F16*F16 +// GMMA 64x32x16 F32+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x208x16_F16F16F16_RS +struct MMA_64x32x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[52]; + using CRegisters = float[16]; static_assert(tnspA == GMMA::Major::K, "Register source operand A must have K major layout."); @@ -1757,19 +1318,10 @@ struct MMA_64x208x16_F16F16F16_RS CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, 
uint32_t const& a03, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -1777,76 +1329,54 @@ struct MMA_64x208x16_F16F16F16_RS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %57, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k16.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51}," - "{%52, %53, %54, %55}," - " %56," - " p, %58, %59, %60;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23, %24;\n" "}\n" - : "+r"(d00), 
"+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x16 F16+=F16*F16 +// GMMA 64x64x16 F32+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x224x16_F16F16F16_SS +struct MMA_64x64x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; + using CRegisters = float[32]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t 
& d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -1854,59 +1384,48 @@ struct MMA_64x224x16_F16F16F16_SS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %58, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k16.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " p, %59, %60, %61, %62;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, 
%17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36, %37, %38;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x16 F16+=F16*F16 +// GMMA 64x64x16 F32+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x224x16_F16F16F16_RS +struct MMA_64x64x16_F32F16F16_RS { using DRegisters = void; using 
ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; + using CRegisters = float[32]; static_assert(tnspA == GMMA::Major::K, "Register source operand A must have K major layout."); @@ -1914,20 +1433,14 @@ struct MMA_64x224x16_F16F16F16_RS CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -1935,78 +1448,64 @@ struct MMA_64x224x16_F16F16F16_RS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %61, 0;\n" - 
"wgmma.mma_async.sync.aligned.m64n224k16.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " p, %62, %63, %64;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39, %40;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F16F16F16_RS without 
CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x16 F16+=F16*F16 +// GMMA 64x96x16 F32+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x240x16_F16F16F16_SS +struct MMA_64x96x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[60]; + using CRegisters = float[48]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & 
d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -2014,61 +1513,54 @@ struct MMA_64x240x16_F16F16F16_SS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %62, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k16.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59}," - " %60," - " %61," - " p, %63, %64, %65, %66;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52, %53, %54;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - 
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x16 F16+=F16*F16 +// GMMA 64x96x16 F32+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x240x16_F16F16F16_RS +struct MMA_64x96x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[60]; + using CRegisters = float[48]; static_assert(tnspA == GMMA::Major::K, "Register source operand A must have K major layout."); @@ -2076,21 +1568,18 @@ struct MMA_64x240x16_F16F16F16_RS CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, 
uint32_t const& a03, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -2098,80 +1587,74 @@ struct MMA_64x240x16_F16F16F16_RS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %65, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k16.f16.f16.f16 " + "setp.ne.b32 p, %53, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n96k16.f32.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59}," - "{%60, %61, %62, %63}," - " %64," - " p, %66, %67, %68;\n" + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55, %56;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x240x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x256x16 F16+=F16*F16 +// GMMA 64x128x16 F32+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x256x16_F16F16F16_SS +struct MMA_64x128x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; + using CRegisters = float[64]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, 
float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -2180,7 +1663,7 @@ struct MMA_64x256x16_F16F16F16_SS "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %66, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16 " + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -2193,46 +1676,46 @@ struct MMA_64x256x16_F16F16F16_SS " %65," " p, %67, %68, %69, %70;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), 
"+r"(d63) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x256x16 F16+=F16*F16 +// GMMA 64x128x16 F32+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x256x16_F16F16F16_RS +struct MMA_64x128x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; + using CRegisters = float[64]; static_assert(tnspA == GMMA::Major::K, "Register source operand A must have K major layout."); @@ -2240,22 +1723,22 @@ struct MMA_64x256x16_F16F16F16_RS CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & 
d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -2264,7 +1747,7 @@ struct 
MMA_64x256x16_F16F16F16_RS "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %69, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16 " + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -2277,51 +1760,74 @@ struct MMA_64x256x16_F16F16F16_RS " %68," " p, %70, %71, %72;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x8x16 F32+=F16*F16 +// GMMA 64x192x16 F32+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x8x16_F32F16F16_SS +struct MMA_64x192x16_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[4]; + using CRegisters = float[96]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float 
& d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -2329,46 +1835,103 @@ struct MMA_64x8x16_F32F16F16_SS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %6, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 " - "{%0, %1, %2, %3}," - " %4," - " %5," - " p, %7, %8, %9, %10;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100, %101, %102;\n" "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + 
"+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x8x16 F32+=F16*F16 +// GMMA 64x192x16 F32+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x8x16_F32F16F16_RS +struct MMA_64x192x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[4]; + using CRegisters = float[96]; static_assert(tnspA == GMMA::Major::K, "Register source operand A must have K major layout."); CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, 
float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -2376,44 +1939,108 @@ struct MMA_64x8x16_F32F16F16_RS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %9, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 " - "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," - " %8," - " p, %10, %11, %12;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " 
%100," + " p, %102, %103, %104;\n" "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x16x16 F32+=F16*F16 +// GMMA 64x256x16 F32+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x16x16_F32F16F16_SS +struct MMA_64x256x16_F32F16F16_SS { using 
DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[8]; + using CRegisters = float[128]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, - float & d4, float & d5, float & d6, float & d7, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, 
float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -2421,48 +2048,123 @@ struct MMA_64x16x16_F32F16F16_SS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %10, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - " %8," - " %9," - " p, %11, %12, %13, %14;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132, %133, %134;\n" "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), - "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + 
"+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x16x16 F32+=F16*F16 +// GMMA 64x256x16 F32+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x16x16_F32F16F16_RS +struct MMA_64x256x16_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using 
CRegisters = float[8]; + using CRegisters = float[128]; static_assert(tnspA == GMMA::Major::K, "Register source operand A must have K major layout."); CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, - float & d4, float & d5, float & d6, float & d7, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float 
& d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -2470,47 +2172,89 @@ struct MMA_64x16x16_F32F16F16_RS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %13, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "{%8, %9, %10, %11}," - " %12," - " p, %14, %15, %16;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135, %136;\n" "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), - "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), 
"+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x32x16 F32+=F16*F16 +// GMMA 64x8x16 F32+=BF16*BF16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x32x16_F32F16F16_SS +struct MMA_64x8x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[16]; + using CRegisters = float[4]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, + float & d0, float & d1, float & d2, float & d3, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -2518,53 +2262,46 @@ struct MMA_64x32x16_F32F16F16_SS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %18, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - " %16," - " %17," - " p, %19, %20, %21, %22;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8, %9, %10;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x32x16 F32+=F16*F16 +// GMMA 64x8x16 F32+=BF16*BF16 template < GMMA::Major 
tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x32x16_F32F16F16_RS +struct MMA_64x8x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[16]; + using CRegisters = float[4]; static_assert(tnspA == GMMA::Major::K, "Register source operand A must have K major layout."); CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, + float & d0, float & d1, float & d2, float & d3, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -2572,108 +2309,44 @@ struct MMA_64x32x16_F32F16F16_RS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %21, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - "{%16, %17, %18, %19}," - " %20," - " p, %22, %23, %24;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11, %12;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x32x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x40x16 F32+=F16*F16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x40x16_F32F16F16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[20]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %22, 0;\n" - "wgmma.mma_async.sync.aligned.m64n40k16.f32.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19}," - " %20," - " %21," - " p, %23, %24, %25, %26;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif 
//////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x16 F32+=F16*F16 +// GMMA 64x16x16 F32+=BF16*BF16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x48x16_F32F16F16_SS +struct MMA_64x16x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[24]; + using CRegisters = float[8]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -2681,60 +2354,48 @@ struct MMA_64x48x16_F32F16F16_SS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %26, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k16.f32.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " p, %27, %28, %29, %30;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12, %13, %14;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : 
"+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x16 F32+=F16*F16 +// GMMA 64x16x16 F32+=BF16*BF16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x48x16_F32F16F16_RS +struct MMA_64x16x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[24]; + using CRegisters = float[8]; static_assert(tnspA == GMMA::Major::K, "Register source operand A must have K major layout."); CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -2742,46 +2403,39 @@ struct MMA_64x48x16_F32F16F16_RS asm volatile( "{\n" ".reg .pred p;\n" - 
"setp.ne.b32 p, %29, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k16.f32.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " p, %30, %31, %32;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15, %16;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x64x16 F32+=F16*F16 +// GMMA 64x32x16 F32+=BF16*BF16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x64x16_F32F16F16_SS +struct MMA_64x32x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[32]; + using CRegisters = float[16]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -2790,10 +2444,6 @@ struct MMA_64x64x16_F32F16F16_SS float & d04, float & d05, float & d06, float & d07, float & d08, float & d09, float & d10, float & 
d11, float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -2801,48 +2451,42 @@ struct MMA_64x64x16_F32F16F16_SS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %34, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 " + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - " %32," - " %33," - " p, %35, %36, %37, %38;\n" + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20, %21, %22;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x64x16 F32+=F16*F16 +// GMMA 64x32x16 F32+=BF16*BF16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x64x16_F32F16F16_RS +struct MMA_64x32x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[32]; + using CRegisters = float[16]; static_assert(tnspA == GMMA::Major::K, "Register source operand A must have K major layout."); @@ -2854,10 +2498,6 @@ struct MMA_64x64x16_F32F16F16_RS float & d04, float & d05, float & d06, float & d07, float & d08, float & d09, float & d10, float & d11, float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -2865,49 +2505,42 @@ struct MMA_64x64x16_F32F16F16_RS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %37, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 " + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - "{%32, %33, %34, %35}," - " %36," - " p, %38, %39, %40;\n" + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23, %24;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x16 F32+=F16*F16 +// GMMA 64x64x16 F32+=BF16*BF16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x80x16_F32F16F16_SS +struct MMA_64x64x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[40]; + using CRegisters = float[32]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -2920,8 +2553,6 @@ struct MMA_64x80x16_F32F16F16_SS float & d20, float & d21, float & d22, float & d23, float & d24, float & d25, float & d26, float & d27, float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -2929,16 +2560,15 @@ struct MMA_64x80x16_F32F16F16_SS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %42, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k16.f32.f16.f16 " + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " p, %43, %44, %45, %46;\n" + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36, 
%37, %38;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -2947,35 +2577,31 @@ struct MMA_64x80x16_F32F16F16_SS "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x16 F32+=F16*F16 +// GMMA 64x64x16 F32+=BF16*BF16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x80x16_F32F16F16_RS +struct MMA_64x64x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[40]; + using CRegisters = float[32]; static_assert(tnspA == GMMA::Major::K, "Register source operand A must have K major layout."); @@ -2991,8 +2617,6 @@ struct MMA_64x80x16_F32F16F16_RS float & d20, float & d21, float & d22, float & d23, float & d24, float & d25, float & d26, float & d27, float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -3000,16 +2624,15 @@ struct 
MMA_64x80x16_F32F16F16_RS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %45, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k16.f32.f16.f16 " + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " p, %46, %47, %48;\n" + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39, %40;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -3018,29 +2641,26 @@ struct MMA_64x80x16_F32F16F16_RS "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x96x16 F32+=F16*F16 +// GMMA 64x96x16 F32+=BF16*BF16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x96x16_F32F16F16_SS +struct MMA_64x96x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -3070,7 +2690,7 @@ struct MMA_64x96x16_F32F16F16_SS "{\n" ".reg .pred p;\n" 
"setp.ne.b32 p, %50, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k16.f32.f16.f16 " + "wgmma.mma_async.sync.aligned.m64n96k16.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -3097,21 +2717,21 @@ struct MMA_64x96x16_F32F16F16_SS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x96x16 F32+=F16*F16 +// GMMA 64x96x16 F32+=BF16*BF16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x96x16_F32F16F16_RS +struct MMA_64x96x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -3144,7 +2764,7 @@ struct MMA_64x96x16_F32F16F16_RS "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %53, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k16.f32.f16.f16 " + "wgmma.mma_async.sync.aligned.m64n96k16.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -3171,27 +2791,26 @@ struct MMA_64x96x16_F32F16F16_RS "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x16 F32+=F16*F16 +// GMMA 64x128x16 F32+=BF16*BF16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x112x16_F32F16F16_SS +struct MMA_64x128x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[56]; + using CRegisters = float[64]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -3210,6 +2829,8 @@ struct MMA_64x112x16_F32F16F16_SS float & d44, float & d45, float & d46, float & d47, float & d48, float & d49, float & d50, float & d51, float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -3217,18 +2838,19 @@ struct MMA_64x112x16_F32F16F16_SS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %58, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k16.f32.f16.f16 " + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " p, %59, %60, %61, %62;\n" + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68, %69, %70;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -3243,33 +2865,33 @@ struct MMA_64x112x16_F32F16F16_SS "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - 
"+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x16 F32+=F16*F16 +// GMMA 64x128x16 F32+=BF16*BF16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x112x16_F32F16F16_RS +struct MMA_64x128x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[56]; + using CRegisters = float[64]; static_assert(tnspA == GMMA::Major::K, "Register source operand A must have K major layout."); @@ -3291,6 +2913,8 @@ struct MMA_64x112x16_F32F16F16_RS float & d44, float & d45, float & d46, float & d47, float & d48, float & d49, float & d50, float & d51, float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -3298,18 +2922,19 @@ struct MMA_64x112x16_F32F16F16_RS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %61, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k16.f32.f16.f16 " + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, 
%14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " p, %62, %63, %64;\n" + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71, %72;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -3324,32 +2949,33 @@ struct MMA_64x112x16_F32F16F16_RS "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x128x16 F32+=F16*F16 +// GMMA 64x192x16 F32+=BF16*BF16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x128x16_F32F16F16_SS +struct MMA_64x192x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[64]; + using CRegisters = float[96]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -3370,6 +2996,14 @@ struct MMA_64x128x16_F32F16F16_SS float & d52, float & d53, float & d54, 
float & d55, float & d56, float & d57, float & d58, float & d59, float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -3377,8 +3011,8 @@ struct MMA_64x128x16_F32F16F16_SS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %66, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 " + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k16.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -3386,10 +3020,14 @@ struct MMA_64x128x16_F32F16F16_SS " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - " %64," - " %65," - " p, %67, %68, %69, %70;\n" + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100, %101, %102;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -3406,31 +3044,39 @@ struct MMA_64x128x16_F32F16F16_SS "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), 
"+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x128x16 F32+=F16*F16 +// GMMA 64x192x16 F32+=BF16*BF16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x128x16_F32F16F16_RS +struct MMA_64x192x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[64]; + using CRegisters = float[96]; static_assert(tnspA == GMMA::Major::K, "Register source operand A must have K major layout."); @@ -3454,6 +3100,14 @@ struct MMA_64x128x16_F32F16F16_RS float & d52, float & d53, float & d54, float & d55, float & d56, float & d57, float & d58, float & d59, float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, 
float & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -3461,8 +3115,8 @@ struct MMA_64x128x16_F32F16F16_RS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %69, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 " + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k16.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -3470,10 +3124,14 @@ struct MMA_64x128x16_F32F16F16_RS " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - "{%64, %65, %66, %67}," - " %68," - " p, %70, %71, %72;\n" + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103, %104;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -3490,54 +3148,75 @@ struct MMA_64x128x16_F32F16F16_RS "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x16 F32+=F16*F16 +// GMMA 64x256x16 F32+=BF16*BF16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x144x16_F32F16F16_SS +struct MMA_64x256x16_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[72]; + using CRegisters = float[128]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & 
d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -3545,8 +3224,8 @@ struct MMA_64x144x16_F32F16F16_SS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %74, 0;\n" - 
"wgmma.mma_async.sync.aligned.m64n144k16.f32.f16.f16 " + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -3555,80 +3234,113 @@ struct MMA_64x144x16_F32F16F16_SS " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - " %72," - " %73," - " p, %75, %76, %77, %78;\n" + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132, %133, %134;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + 
"+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// 
GMMA 64x144x16 F32+=F16*F16 +// GMMA 64x256x16 F32+=BF16*BF16 template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x144x16_F32F16F16_RS +struct MMA_64x256x16_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[72]; + using CRegisters = float[128]; static_assert(tnspA == GMMA::Major::K, "Register source operand A must have K major layout."); CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & 
d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -3636,8 +3348,8 @@ struct MMA_64x144x16_F32F16F16_RS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %77, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k16.f32.f16.f16 " + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, 
%10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -3646,79 +3358,77 @@ struct MMA_64x144x16_F32F16F16_RS " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - "{%72, %73, %74, %75}," - " %76," - " p, %78, %79, %80;\n" + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135, %136;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), 
"+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x16 F32+=F16*F16 +// GMMA 64x8x8 TN F32+=TF32*TF32 template < - GMMA::Major tnspA, - GMMA::Major tnspB, 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x160x16_F32F16F16_SS +struct MMA_64x8x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[80]; + using CRegisters = float[4]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, + float & d0, float & d1, float & d2, float & d3, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -3726,95 +3436,41 @@ struct MMA_64x160x16_F32F16F16_SS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %82, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k16.f32.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, 
%36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - " %80," - " %81," - " p, %83, %84, %85, %86;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x16 F32+=F16*F16 +// GMMA 64x8x8 TN F32+=TF32*TF32 template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x160x16_F32F16F16_RS +struct MMA_64x8x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[80]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); + using CRegisters = float[4]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, + float & d0, float & d1, float & d2, float & d3, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -3822,94 +3478,42 @@ struct MMA_64x160x16_F32F16F16_RS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %85, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k16.f32.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - "{%80, %81, %82, %83}," - " %84," - " p, %86, %87, %88;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x16 F32+=F16*F16 +// GMMA 64x16x8 TN F32+=TF32*TF32 template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x176x16_F32F16F16_SS +struct MMA_64x16x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[88]; + using CRegisters = float[8]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & 
d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -3917,22 +3521,331 @@ struct MMA_64x176x16_F32F16F16_SS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %90, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k16.f32.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - " %88," - " %89," - " p, %91, %92, %93, %94;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters 
= uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred 
p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), 
"+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), 
"+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), 
"+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n96k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -3945,46 +3858,29 @@ struct MMA_64x176x16_F32F16F16_SS "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x16 F32+=F16*F16 +// GMMA 64x96x8 TN F32+=TF32*TF32 template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB 
= GMMA::ScaleIn::One > -struct MMA_64x176x16_F32F16F16_RS +struct MMA_64x96x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[88]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); + using CRegisters = float[48]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -4001,16 +3897,6 @@ struct MMA_64x176x16_F32F16F16_RS float & d36, float & d37, float & d38, float & d39, float & d40, float & d41, float & d42, float & d43, float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -4018,22 +3904,17 @@ struct MMA_64x176x16_F32F16F16_RS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %93, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k16.f32.f16.f16 " + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k8.f32.tf32.tf32 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, 
%75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - "{%88, %89, %90, %91}," - " %92," - " p, %94, %95, %96;\n" + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -4046,42 +3927,29 @@ struct MMA_64x176x16_F32F16F16_RS "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x192x16 F32+=F16*F16 +// GMMA 64x128x8 TN F32+=TF32*TF32 template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x192x16_F32F16F16_SS +struct MMA_64x128x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using 
BRegisters = uint64_t[1]; - using CRegisters = float[96]; + using CRegisters = float[64]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -4102,14 +3970,6 @@ struct MMA_64x192x16_F32F16F16_SS float & d52, float & d53, float & d54, float & d55, float & d56, float & d57, float & d58, float & d59, float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, - float & d88, float & d89, float & d90, float & d91, - float & d92, float & d93, float & d94, float & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -4117,8 +3977,8 @@ struct MMA_64x192x16_F32F16F16_SS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %98, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k16.f32.f16.f16 " + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k8.f32.tf32.tf32 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -4126,14 +3986,10 @@ struct MMA_64x192x16_F32F16F16_SS " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - " %96," - " %97," - " p, %99, %100, %101, %102;\n" + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -4150,42 +4006,29 @@ struct MMA_64x192x16_F32F16F16_SS "+f"(d48), 
"+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), - "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), - "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x192x16 F32+=F16*F16 +// GMMA 64x128x8 TN F32+=TF32*TF32 template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x192x16_F32F16F16_RS +struct MMA_64x128x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[96]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); + using CRegisters = float[64]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -4206,14 +4049,6 @@ struct MMA_64x192x16_F32F16F16_RS float & d52, float & d53, float & d54, float & d55, float & d56, float & d57, float & d58, float & d59, float & d60, float & d61, float & d62, float 
& d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, - float & d88, float & d89, float & d90, float & d91, - float & d92, float & d93, float & d94, float & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -4221,8 +4056,8 @@ struct MMA_64x192x16_F32F16F16_RS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %101, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k16.f32.f16.f16 " + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k8.f32.tf32.tf32 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -4230,14 +4065,10 @@ struct MMA_64x192x16_F32F16F16_RS " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - "{%96, %97, %98, %99}," - " %100," - " p, %102, %103, %104;\n" + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -4254,70 +4085,57 @@ struct MMA_64x192x16_F32F16F16_RS "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), 
"+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), - "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), - "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x16 F32+=F16*F16 +// GMMA 64x192x8 TN F32+=TF32*TF32 template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x208x16_F32F16F16_SS +struct MMA_64x192x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[104]; + using CRegisters = float[96]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float 
& d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & 
d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -4325,8 +4143,8 @@ struct MMA_64x208x16_F32F16F16_SS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %106, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k16.f32.f16.f16 " + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k8.f32.tf32.tf32 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -4338,97 +4156,85 @@ struct MMA_64x208x16_F32F16F16_SS " %64, %65, %66, %67, %68, %69, %70, %71, " " %72, %73, %74, %75, %76, %77, %78, %79, " " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - " %104," - " %105," - " p, %107, %108, %109, %110;\n" + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), 
"+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; 
-#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x16 F32+=F16*F16 +// GMMA 64x192x8 TN F32+=TF32*TF32 template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x208x16_F32F16F16_RS +struct MMA_64x192x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[104]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); + using CRegisters = float[96]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, 
float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -4436,8 +4242,8 @@ struct MMA_64x208x16_F32F16F16_RS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %109, 0;\n" - 
"wgmma.mma_async.sync.aligned.m64n208k16.f32.f16.f16 " + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k8.f32.tf32.tf32 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -4449,64 +4255,57 @@ struct MMA_64x208x16_F32F16F16_RS " %64, %65, %66, %67, %68, %69, %70, %71, " " %72, %73, %74, %75, %76, %77, %78, %79, " " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - "{%104, %105, %106, %107}," - " %108," - " p, %110, %111, %112;\n" + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), 
"+f"(d101), "+f"(d102), "+f"(d103) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x16 F32+=F16*F16 +// GMMA 64x256x8 TN F32+=TF32*TF32 template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB 
= GMMA::ScaleIn::One > -struct MMA_64x224x16_F32F16F16_SS +struct MMA_64x256x8_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[112]; + using CRegisters = float[128]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -4539,6 +4338,10 @@ struct MMA_64x224x16_F32F16F16_SS float & d100, float & d101, float & d102, float & d103, float & d104, float & d105, float & d106, float & d107, float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -4546,8 +4349,8 @@ struct MMA_64x224x16_F32F16F16_SS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %114, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k16.f32.f16.f16 " + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k8.f32.tf32.tf32 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -4561,10 +4364,12 @@ struct MMA_64x224x16_F32F16F16_SS " %80, %81, %82, %83, %84, %85, %86, %87, " " %88, %89, %90, %91, %92, %93, %94, %95, " " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - " %112," - " %113," - " p, %115, %116, %117, %118;\n" + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132;\n" "}\n" : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), @@ -4593,36 +4398,33 @@ struct MMA_64x224x16_F32F16F16_SS "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), "+f"(d100), "+f"(d101), 
"+f"(d102), "+f"(d103), "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x16 F32+=F16*F16 +// GMMA 64x256x8 TN F32+=TF32*TF32 template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x224x16_F32F16F16_RS +struct MMA_64x256x8_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[112]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); + using CRegisters = float[128]; CUTE_HOST_DEVICE static void fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, @@ -4655,6 +4457,10 @@ struct MMA_64x224x16_F32F16F16_RS float & d100, float & d101, float & d102, float & d103, float & d104, float & d105, float & d106, float & d107, float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float 
& d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -4662,8 +4468,8 @@ struct MMA_64x224x16_F32F16F16_RS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %117, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k16.f32.f16.f16 " + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k8.f32.tf32.tf32 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -4677,10 +4483,12 @@ struct MMA_64x224x16_F32F16F16_RS " %80, %81, %82, %83, %84, %85, %86, %87, " " %88, %89, %90, %91, %92, %93, %94, %95, " " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - "{%112, %113, %114, %115}," - " %116," - " p, %118, %119, %120;\n" + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135;\n" "}\n" : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), @@ -4709,67 +4517,34 @@ struct MMA_64x224x16_F32F16F16_RS "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x16 F32+=F16*F16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x240x16_F32F16F16_SS +// GMMA 64x8x32 TN S32+=S8*S8 +struct MMA_64x8x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[120]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & 
d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -4777,239 +4552,76 @@ struct MMA_64x240x16_F32F16F16_SS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %122, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k16.f32.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - " %120," - " %121," - " p, %123, %124, %125, %126;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - 
"+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x16 F32+=F16*F16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x240x16_F32F16F16_RS +// GMMA 64x8x32 TN S32+=S8*S8 +struct MMA_64x8x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[120]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, 
float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %125, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k16.f32.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - "{%120, %121, %122, %123}," - " %124," - " p, %126, %127, %128;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - 
"+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// 
GMMA 64x256x16 F32+=F16*F16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x256x16_F32F16F16_SS +// GMMA 64x16x32 TN S32+=S8*S8 +struct MMA_64x16x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[128]; + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, 
float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, - float & d120, float & d121, float & d122, float & d123, - float & d124, float & d125, float & d126, float & d127, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -5017,213 +4629,81 @@ struct MMA_64x256x16_F32F16F16_SS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %130, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - " %128," - " %129," - " p, %131, %132, %133, %134;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - 
"+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), - "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), - "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x256x16 F32+=F16*F16 -template < - 
GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x256x16_F32F16F16_RS +// GMMA 64x16x32 TN S32+=S8*S8 +struct MMA_64x16x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[128]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float 
& d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, - float & d120, float & d121, float & d122, float & d123, - float & d124, float & d125, float & d126, float & d127, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %133, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - "{%128, %129, %130, %131}," - " %132," - " p, %134, %135, %136;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, 
%6, %7}," + " %8," + " %9," + " p;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), - "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), - "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x256x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x8x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x8x16_F32BF16BF16_SS +// GMMA 64x32x32 TN S32+=S8*S8 +struct MMA_64x32x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[4]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -5231,91 +4711,93 @@ struct MMA_64x8x16_F32BF16BF16_SS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %6, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16 " - "{%0, %1, %2, %3}," - " %4," - " %5," - " p, %7, %8, %9, %10;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x8x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x8x16_F32BF16BF16_RS +// GMMA 64x32x32 TN S32+=S8*S8 +struct MMA_64x32x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[4]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %9, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16 " - "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," - " %8," - " p, %10, %11, %12;\n" + "setp.ne.b32 p, %18, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x16x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x16x16_F32BF16BF16_SS +// GMMA 64x64x32 TN S32+=S8*S8 +struct MMA_64x64x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[8]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, - float & d4, float & d5, float & d6, float & d7, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, 
uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -5323,96 +4805,113 @@ struct MMA_64x16x16_F32BF16BF16_SS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %10, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - " %8," - " %9," - " p, %11, %12, %13, %14;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), - "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x16x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x16x16_F32BF16BF16_RS +// GMMA 64x64x32 TN S32+=S8*S8 +struct 
MMA_64x64x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[8]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, - float & d4, float & d5, float & d6, float & d7, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %13, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "{%8, %9, %10, %11}," - " %12," - " p, %14, %15, %16;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), - "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) - : "r"(a0), "r"(a1), "r"(a2), 
"r"(a3), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x32x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x32x16_F32BF16BF16_SS +// GMMA 64x96x32 TN S32+=S8*S8 +struct MMA_64x96x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[16]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, 
uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -5420,26905 +4919,113 @@ struct MMA_64x32x16_F32BF16BF16_SS asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %18, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16 " + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - " %16," - " %17," - " p, %19, %20, %21, %22;\n" + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x32x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x32x16_F32BF16BF16_RS +// GMMA 64x96x32 TN S32+=S8*S8 +struct MMA_64x96x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[16]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t 
& d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %21, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16 " + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - "{%16, %17, %18, %19}," - " %20," - " p, %22, %23, %24;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x40x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x40x16_F32BF16BF16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[20]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & 
d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %22, 0;\n" - "wgmma.mma_async.sync.aligned.m64n40k16.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19}," - " %20," - " %21," - " p, %23, %24, %25, %26;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x48x16_F32BF16BF16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[24]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - GMMA::ScaleOut const 
scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %26, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k16.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " p, %27, %28, %29, %30;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x48x16_F32BF16BF16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[24]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & 
d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %29, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k16.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " p, %30, %31, %32;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x64x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x64x16_F32BF16BF16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[32]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & 
d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %34, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - " %32," - " %33," - " p, %35, %36, %37, %38;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x64x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x64x16_F32BF16BF16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[32]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, 
uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %37, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - "{%32, %33, %34, %35}," - " %36," - " p, %38, %39, %40;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One -> -struct MMA_64x80x16_F32BF16BF16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[40]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %42, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k16.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " p, %43, %44, %45, %46;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), 
"n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x80x16_F32BF16BF16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[40]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %45, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k16.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, 
%43}," - " %44," - " p, %46, %47, %48;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x96x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x96x16_F32BF16BF16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[48]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %50, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k16.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " p, %51, %52, %53, %54;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x96x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x96x16_F32BF16BF16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[48]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K 
major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %53, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k16.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " p, %54, %55, %56;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) - 
: "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x112x16_F32BF16BF16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[56]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %58, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k16.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, 
%12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " p, %59, %60, %61, %62;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x112x16_F32BF16BF16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[56]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, 
- float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %61, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k16.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " p, %62, %63, %64;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), 
"+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x128x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x128x16_F32BF16BF16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[64]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - 
"setp.ne.b32 p, %66, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - " %64," - " %65," - " p, %67, %68, %69, %70;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x128x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x128x16_F32BF16BF16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[64]; - - static_assert(tnspA 
== GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %69, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - "{%64, %65, %66, %67}," - " %68," - " p, %70, %71, %72;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), 
- "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x144x16_F32BF16BF16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[72]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & 
d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %74, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k16.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - " %72," - " %73," - " p, %75, %76, %77, %78;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), 
"+f"(d71) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x144x16_F32BF16BF16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[72]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & 
d69, float & d70, float & d71, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %77, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k16.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - "{%72, %73, %74, %75}," - " %76," - " p, %78, %79, %80;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x160x16_F32BF16BF16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[80]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %82, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k16.f32.bf16.bf16 " - "{%0, %1, %2, %3, 
%4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - " %80," - " %81," - " p, %83, %84, %85, %86;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x160x16_F32BF16BF16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[80]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %85, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k16.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, 
%30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - "{%80, %81, %82, %83}," - " %84," - " p, %86, %87, %88;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x176x16_F32BF16BF16_SS -{ - using DRegisters = void; - using ARegisters = 
uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[88]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %90, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k16.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, 
%60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - " %88," - " %89," - " p, %91, %92, %93, %94;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x176x16_F32BF16BF16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - 
using CRegisters = float[88]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %93, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k16.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, 
%42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - "{%88, %89, %90, %91}," - " %92," - " p, %94, %95, %96;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x192x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x192x16_F32BF16BF16_SS -{ - using DRegisters = void; - 
using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[96]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, - float & d88, float & d89, float & d90, float & d91, - float & d92, float & d93, float & d94, float & d95, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %98, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k16.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, 
%38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - " %96," - " %97," - " p, %99, %100, %101, %102;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), - "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), - "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x192x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA 
= GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x192x16_F32BF16BF16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[96]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, - float & d88, float & d89, float & d90, float & d91, - float & d92, float & d93, float & d94, float & d95, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - 
"setp.ne.b32 p, %101, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k16.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - "{%96, %97, %98, %99}," - " %100," - " p, %102, %103, %104;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), - "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), - "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x208x16_F32BF16BF16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[104]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, 
- float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %106, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k16.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - " %104," - " %105," - " p, %107, %108, %109, %110;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - 
"+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x208x16_F32BF16BF16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[104]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, 
float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %109, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k16.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - "{%104, %105, %106, %107}," - " %108," - " p, 
%110, %111, %112;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct 
MMA_64x224x16_F32BF16BF16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[112]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %114, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k16.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - " %112," - " %113," - " p, %115, %116, %117, %118;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), 
"+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x224x16_F32BF16BF16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[112]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, 
float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %117, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k16.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - "{%112, %113, %114, %115}," - " %116," - " p, %118, 
%119, %120;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x240x16_F32BF16BF16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[120]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & 
d115, - float & d116, float & d117, float & d118, float & d119, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %122, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k16.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - " %120," - " %121," - " p, %123, %124, %125, %126;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), 
"+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x240x16_F32BF16BF16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[120]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & 
d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %125, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k16.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, 
%49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - "{%120, %121, %122, %123}," - " %124," - " p, %126, %127, %128;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), 
"+f"(d119) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x256x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x256x16_F32BF16BF16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[128]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, 
float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, - float & d120, float & d121, float & d122, float & d123, - float & d124, float & d125, float & d126, float & d127, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %130, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - " %128," - " %129," - " p, %131, %132, %133, %134;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), 
"+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), - "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), - "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x256x16 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One 
-> -struct MMA_64x256x16_F32BF16BF16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[128]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float 
& d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, - float & d120, float & d121, float & d122, float & d123, - float & d124, float & d125, float & d126, float & d127, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %133, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - "{%128, %129, %130, %131}," - " %132," - " p, %134, %135, %136;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), 
"+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), - "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), - "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x8x8 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x8x8_F32TF32TF32_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[4]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %6, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32 " - "{%0, %1, %2, %3}," - " %4," - " %5," - " p, %7, %8;\n" - "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x8x8 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x8x8_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[4]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %9, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32 " - "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," - " %8," - " p, %10, %11;\n" - "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x16x8 TN F32+=TF32*TF32 -template < - 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x16x8_F32TF32TF32_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[8]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, - float & d4, float & d5, float & d6, float & d7, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %10, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - " %8," - " %9," - " p, %11, %12;\n" - "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), - "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x16x8 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x16x8_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[8]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, - float & d4, float & d5, float & d6, float & d7, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - 
"{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %13, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "{%8, %9, %10, %11}," - " %12," - " p, %14, %15;\n" - "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), - "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x32x8 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x32x8_F32TF32TF32_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[16]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %18, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k8.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - " %16," - " %17," - " p, %19, %20;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); 
-#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x32x8 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x32x8_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[16]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %21, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k8.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - "{%16, %17, %18, %19}," - " %20," - " p, %22, %23;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x8 TN 
F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x48x8_F32TF32TF32_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[24]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %26, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k8.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " p, %27, %28;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x8 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One 
-> -struct MMA_64x48x8_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[24]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %29, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k8.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " p, %30, %31;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x64x8 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x64x8_F32TF32TF32_SS_TN -{ - using DRegisters = void; - 
using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[32]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %34, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - " %32," - " %33," - " p, %35, %36;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x64x8 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> 
-struct MMA_64x64x8_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[32]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %37, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - "{%32, %33, %34, %35}," - " %36," - " p, %38, %39;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x8 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x80x8_F32TF32TF32_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[40]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %42, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k8.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " p, %43, %44;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), 
"+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x8 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x80x8_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[40]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %45, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k8.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " 
%32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " p, %46, %47;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x96x8 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x96x8_F32TF32TF32_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[48]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %50, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k8.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " p, %51, %52;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x96x8 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x96x8_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[48]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, 
float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %53, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k8.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " p, %54, %55;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x96x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x8 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x112x8_F32TF32TF32_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[56]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %58, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k8.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, 
%53, %54, %55}," - " %56," - " %57," - " p, %59, %60;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x8 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x112x8_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[56]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, 
float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %61, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k8.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " p, %62, %63;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x128x8 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x128x8_F32TF32TF32_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[64]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %66, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k8.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, 
%58, %59, %60, %61, %62, %63}," - " %64," - " %65," - " p, %67, %68;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x128x8 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x128x8_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[64]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & 
d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %69, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k8.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - "{%64, %65, %66, %67}," - " %68," - " p, %70, %71;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - 
"l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x8 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x144x8_F32TF32TF32_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[72]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %74, 0;\n" - 
"wgmma.mma_async.sync.aligned.m64n144k8.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - " %72," - " %73," - " p, %75, %76;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x8 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x144x8_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using 
BRegisters = uint64_t[1]; - using CRegisters = float[72]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %77, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k8.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - "{%72, %73, %74, %75}," - " %76," - " p, %78, %79;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - 
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x8 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x160x8_F32TF32TF32_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[80]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & 
d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %82, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k8.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - " %80," - " %81," - " p, %83, %84;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), 
"+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x8 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x160x8_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[80]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float 
& d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %85, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k8.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - "{%80, %81, %82, %83}," - " %84," - " p, %86, %87;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), 
"+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x8 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x176x8_F32TF32TF32_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[88]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & 
d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %90, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k8.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - " %88," - " %89," - " p, %91, %92;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) - : 
"l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x8 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x176x8_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[88]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - 
float & d84, float & d85, float & d86, float & d87, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %93, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k8.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - "{%88, %89, %90, %91}," - " %92," - " p, %94, %95;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x192x8 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x192x8_F32TF32TF32_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[96]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, - float & d88, float & d89, float & d90, float & d91, - float & d92, float & d93, float & 
d94, float & d95, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %98, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k8.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - " %96," - " %97," - " p, %99, %100;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), - "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), - "+f"(d92), "+f"(d93), 
"+f"(d94), "+f"(d95) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x192x8 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x192x8_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[96]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & 
d86, float & d87, - float & d88, float & d89, float & d90, float & d91, - float & d92, float & d93, float & d94, float & d95, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %101, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k8.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - "{%96, %97, %98, %99}," - " %100," - " p, %102, %103;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), 
- "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), - "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), - "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x8 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x208x8_F32TF32TF32_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[104]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float 
& d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %106, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k8.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - " %104," - " %105," - " p, %107, %108;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), 
"+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x8 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x208x8_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[104]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, 
- float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %109, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k8.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - "{%104, 
%105, %106, %107}," - " %108," - " p, %110, %111;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x8 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x224x8_F32TF32TF32_SS_TN -{ - using 
DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[112]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, 
desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %114, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k8.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - " %112," - " %113," - " p, %115, %116;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - 
"+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x8 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x224x8_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[112]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & 
d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %117, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k8.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - "{%112, %113, %114, %115}," - " %116," - " p, %118, %119;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), 
"+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x8 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x240x8_F32TF32TF32_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[120]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, 
- float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %122, 0;\n" - 
"wgmma.mma_async.sync.aligned.m64n240k8.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - " %120," - " %121," - " p, %123, %124;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), 
"+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x8 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x240x8_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[120]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - 
float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %125, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k8.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - "{%120, %121, %122, %123}," - " %124," - " p, %126, %127;\n" - "}\n" - : 
"+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x256x8 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB 
= GMMA::ScaleIn::One -> -struct MMA_64x256x8_F32TF32TF32_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[128]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, 
float & d118, float & d119, - float & d120, float & d121, float & d122, float & d123, - float & d124, float & d125, float & d126, float & d127, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %130, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k8.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - " %128," - " %129," - " p, %131, %132;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), 
"+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), - "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), - "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x256x8 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x256x8_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[128]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & 
d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, - float & d120, float & d121, float & d122, float & d123, - float & d124, float & d125, float & d126, float & d127, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %133, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k8.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - 
" %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - "{%128, %129, %130, %131}," - " %132," - " p, %134, %135;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), 
"+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), - "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), - "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x8x32 TN S32+=S8*S8 -struct MMA_64x8x32_S32S8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %6, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8 " - "{%0, %1, %2, %3}," - " %4," - " %5," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x8x32 TN S32+=S8*S8 -struct MMA_64x8x32_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d0, 
uint32_t & d1, uint32_t & d2, uint32_t & d3, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %6, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3}," - " %4," - " %5," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x16x32 TN S32+=S8*S8 -struct MMA_64x16x32_S32S8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %10, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - " %8," - " %9," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x16x32 TN S32+=S8*S8 -struct 
MMA_64x16x32_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %10, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - " %8," - " %9," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x32x32 TN S32+=S8*S8 -struct MMA_64x32x32_S32S8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %18, 0;\n" - 
"wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - " %16," - " %17," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x32x32 TN S32+=S8*S8 -struct MMA_64x32x32_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %18, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - " %16," - " %17," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8S8_SS_TN_SATURATE without 
CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x32 TN S32+=S8*S8 -struct MMA_64x48x32_S32S8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %26, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// 
GMMA 64x48x32 TN S32+=S8*S8 -struct MMA_64x48x32_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %26, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x64x32 TN S32+=S8*S8 -struct MMA_64x64x32_S32S8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; - - CUTE_HOST_DEVICE static void - 
fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %34, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - " %32," - " %33," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x64x32 TN S32+=S8*S8 -struct MMA_64x64x32_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; - - CUTE_HOST_DEVICE static 
void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %34, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - " %32," - " %33," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x32 TN S32+=S8*S8 -struct MMA_64x80x32_S32S8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = 
uint64_t[1]; - using CRegisters = uint32_t[40]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %42, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); 
-#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x32 TN S32+=S8*S8 -struct MMA_64x80x32_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %42, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), 
"+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x96x32 TN S32+=S8*S8 -struct MMA_64x96x32_S32S8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %50, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " 
%24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x96x32 TN S32+=S8*S8 -struct MMA_64x96x32_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & 
d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %50, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x32 TN S32+=S8*S8 -struct MMA_64x112x32_S32S8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & 
d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %58, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), 
"+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x32 TN S32+=S8*S8 -struct MMA_64x112x32_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %58, 0;\n" - 
"wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x128x32 TN S32+=S8*S8 -struct MMA_64x128x32_S32S8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - 
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %66, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - " %64," - " %65," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), 
"+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x128x32 TN S32+=S8*S8 -struct MMA_64x128x32_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm 
volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %66, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - " %64," - " %65," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x32 TN S32+=S8*S8 -struct MMA_64x144x32_S32S8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[72]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & 
d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %74, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - " %72," - " %73," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - 
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x32 TN S32+=S8*S8 -struct MMA_64x144x32_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[72]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, 
uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %74, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - " %72," - " %73," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - 
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x32 TN S32+=S8*S8 -struct MMA_64x160x32_S32S8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[80]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & 
d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %82, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - " %80," - " %81," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8S8_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x32 TN S32+=S8*S8 -struct MMA_64x160x32_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[80]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, 
desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %82, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - " %80," - " %81," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x32 TN S32+=S8*S8 -struct MMA_64x176x32_S32S8S8_SS_TN -{ - 
using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[88]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %90, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, 
%6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - " %88," - " %89," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x32 TN S32+=S8*S8 -struct MMA_64x176x32_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using 
ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[88]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %90, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, 
%10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - " %88," - " %89," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x192x32 TN S32+=S8*S8 -struct MMA_64x192x32_S32S8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using 
CRegisters = uint32_t[96]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %98, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8 
" - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - " %96," - " %97," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), - "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), - "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 
64x192x32 TN S32+=S8*S8 -struct MMA_64x192x32_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[96]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %98, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - " %96," - " %97," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), - "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), - "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting 
to use MMA_64x192x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x32 TN S32+=S8*S8 -struct MMA_64x208x32_S32S8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[104]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & 
d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %106, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - " %104," - " %105," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), 
"+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x32 TN S32+=S8*S8 -struct MMA_64x208x32_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[104]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, 
uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %106, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, 
%103}," - " %104," - " %105," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x32 TN S32+=S8*S8 -struct MMA_64x224x32_S32S8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[112]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t 
const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - GMMA::ScaleOut const 
scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %114, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - " %112," - " %113," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), 
"+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x32 TN S32+=S8*S8 -struct MMA_64x224x32_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[112]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, 
uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %114, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - " %112," - " %113," - " p;\n" - "}\n" - : 
"+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x32 TN S32+=S8*S8 -struct MMA_64x240x32_S32S8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[120]; - - CUTE_HOST_DEVICE static 
void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & 
d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %122, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - " %120," - " %121," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), 
"+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x32 TN S32+=S8*S8 -struct MMA_64x240x32_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[120]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & 
d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %122, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, 
%38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - " %120," - " %121," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), 
"+r"(d117), "+r"(d118), "+r"(d119) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x256x32 TN S32+=S8*S8 -struct MMA_64x256x32_S32S8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[128]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, 
uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %130, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - " %128," - " %129," - " p;\n" - "}\n" - : 
"+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), - "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), - "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x256x32 TN S32+=S8*S8 -struct MMA_64x256x32_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = 
void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[128]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & 
d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %130, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - " %128," - " %129," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), 
"+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), - "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), - "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x8x32 TN S32+=S8*S8 -struct MMA_64x8x32_S32S8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %9, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8 " - "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," - " %8," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x8x32 TN S32+=S8*S8 -struct MMA_64x8x32_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %9, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," - " %8," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x16x32 TN S32+=S8*S8 -struct MMA_64x16x32_S32S8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; - - CUTE_HOST_DEVICE 
static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %13, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "{%8, %9, %10, %11}," - " %12," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x16x32 TN S32+=S8*S8 -struct MMA_64x16x32_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %13, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "{%8, %9, %10, %11}," - " %12," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : 
"r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x32x32 TN S32+=S8*S8 -struct MMA_64x32x32_S32S8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %21, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - "{%16, %17, %18, %19}," - " %20," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x32x32 TN S32+=S8*S8 -struct MMA_64x32x32_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using 
ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %21, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - "{%16, %17, %18, %19}," - " %20," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x32 TN S32+=S8*S8 -struct MMA_64x48x32_S32S8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, 
- uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %29, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x32 TN S32+=S8*S8 -struct MMA_64x48x32_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t 
& d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %29, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x64x32 TN S32+=S8*S8 -struct MMA_64x64x32_S32S8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & 
d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %37, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - "{%32, %33, %34, %35}," - " %36," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x64x32 TN S32+=S8*S8 -struct MMA_64x64x32_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, 
uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %37, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - "{%32, %33, %34, %35}," - " %36," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x32 TN S32+=S8*S8 -struct MMA_64x80x32_S32S8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, 
uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %45, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x32 TN S32+=S8*S8 -struct 
MMA_64x80x32_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %45, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), 
"+r"(d38), "+r"(d39) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x96x32 TN S32+=S8*S8 -struct MMA_64x96x32_S32S8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %53, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, 
%37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x96x32 TN S32+=S8*S8 -struct MMA_64x96x32_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, 
uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %53, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x32 TN S32+=S8*S8 -struct MMA_64x112x32_S32S8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, 
uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %61, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), 
"+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x32 TN S32+=S8*S8 -struct MMA_64x112x32_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - GMMA::ScaleOut const scale_D 
= GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %61, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x128x32 TN S32+=S8*S8 -struct MMA_64x128x32_S32S8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, 
uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %69, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - "{%64, %65, %66, %67}," - " %68," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), 
"+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x128x32 TN S32+=S8*S8 -struct MMA_64x128x32_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, 
uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %69, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - "{%64, %65, %66, %67}," - " %68," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x32 TN S32+=S8*S8 -struct MMA_64x144x32_S32S8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[72]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %77, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.s8 " - "{%0, %1, 
%2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - "{%72, %73, %74, %75}," - " %76," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x32 TN S32+=S8*S8 -struct MMA_64x144x32_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[72]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - 
uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %77, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - "{%72, %73, %74, %75}," - " %76," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), 
"+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x32 TN S32+=S8*S8 -struct MMA_64x160x32_S32S8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[80]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, 
uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %85, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - "{%80, %81, %82, %83}," - " %84," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - 
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x32 TN S32+=S8*S8 -struct MMA_64x160x32_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[80]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & 
d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %85, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - "{%80, %81, %82, %83}," - " %84," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), 
"+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x32 TN S32+=S8*S8 -struct MMA_64x176x32_S32S8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[88]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t 
& d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %93, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - "{%88, %89, %90, %91}," - " %92," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - 
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x32 TN S32+=S8*S8 -struct MMA_64x176x32_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[88]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - 
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %93, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - "{%88, %89, %90, %91}," - " %92," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), 
"+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x192x32 TN S32+=S8*S8 -struct MMA_64x192x32_S32S8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[96]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, 
uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %101, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - "{%96, %97, %98, %99}," - " %100," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), 
"+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), - "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), - "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x192x32 TN S32+=S8*S8 -struct MMA_64x192x32_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[96]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & 
d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %101, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - "{%96, %97, %98, %99}," - " %100," - " p;\n" - "}\n" - : "+r"(d00), 
"+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), - "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), - "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x32 TN S32+=S8*S8 -struct MMA_64x208x32_S32S8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[104]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, 
uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %109, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, 
%9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - "{%104, %105, %106, %107}," - " %108," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D))); 
-#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x32 TN S32+=S8*S8 -struct MMA_64x208x32_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[104]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - 
uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %109, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - "{%104, %105, %106, %107}," - " %108," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - 
"+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x32 TN S32+=S8*S8 -struct MMA_64x224x32_S32S8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[112]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - 
uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %117, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, 
%43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - "{%112, %113, %114, %115}," - " %116," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x32 TN S32+=S8*S8 -struct MMA_64x224x32_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[112]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t 
& d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %117, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - "{%112, %113, %114, %115}," - " %116," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), 
"+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x32 TN S32+=S8*S8 -struct MMA_64x240x32_S32S8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[120]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & 
d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %125, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - "{%120, %121, %122, %123}," - " %124," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), 
"+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x32 TN S32+=S8*S8 -struct MMA_64x240x32_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[120]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t 
& d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %125, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, 
%67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - "{%120, %121, %122, %123}," - " %124," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x256x32 TN S32+=S8*S8 -struct MMA_64x256x32_S32S8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[128]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t 
& d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %133, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - "{%128, %129, %130, %131}," - " %132," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), 
"+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), - "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), - "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x256x32 TN S32+=S8*S8 -struct MMA_64x256x32_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = 
void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[128]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, 
uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %133, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - "{%128, %129, %130, %131}," - " %132," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), 
"+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), - "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), - "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x8x32 TN S32+=S8*U8 -struct MMA_64x8x32_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %6, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8 " - "{%0, %1, %2, %3}," - " %4," - " %5," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x8x32 TN S32+=S8*U8 -struct MMA_64x8x32_S32S8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %6, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3}," - " %4," - " %5," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x16x32 TN S32+=S8*U8 -struct MMA_64x16x32_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - 
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %10, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - " %8," - " %9," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x16x32 TN S32+=S8*U8 -struct MMA_64x16x32_S32S8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %10, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - " %8," - " %9," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x32x32 TN S32+=S8*U8 -struct MMA_64x32x32_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %18, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - " %16," - " %17," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x32x32 TN S32+=S8*U8 -struct MMA_64x32x32_S32S8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - 
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %18, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - " %16," - " %17," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x32 TN S32+=S8*U8 -struct MMA_64x48x32_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg 
.pred p;\n" - "setp.ne.b32 p, %26, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x32 TN S32+=S8*U8 -struct MMA_64x48x32_S32S8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %26, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, 
%18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x64x32 TN S32+=S8*U8 -struct MMA_64x64x32_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %34, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - " %32," - " %33," - " p;\n" - "}\n" - : 
"+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x64x32 TN S32+=S8*U8 -struct MMA_64x64x32_S32S8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %34, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - " %32," - " %33," - " 
p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x32 TN S32+=S8*U8 -struct MMA_64x80x32_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %42, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.u8 " - 
"{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x32 TN S32+=S8*U8 -struct MMA_64x80x32_S32S8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & 
d37, uint32_t & d38, uint32_t & d39, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %42, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x96x32 TN S32+=S8*U8 -struct MMA_64x96x32_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & 
d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %50, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x96x32 TN S32+=S8*U8 -struct MMA_64x96x32_S32S8U8_SS_TN_SATURATE 
-{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %50, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - 
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x32 TN S32+=S8*U8 -struct MMA_64x112x32_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %58, 
0;\n" - "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x32 TN S32+=S8*U8 -struct MMA_64x112x32_S32S8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, 
uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %58, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "l"(desc_a), - "l"(desc_b), - 
"r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x128x32 TN S32+=S8*U8 -struct MMA_64x128x32_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %66, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, 
%13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - " %64," - " %65," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x128x32 TN S32+=S8*U8 -struct MMA_64x128x32_S32S8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, 
- uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %66, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - " %64," - " %65," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), 
"+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x32 TN S32+=S8*U8 -struct MMA_64x144x32_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[72]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, 
uint32_t & d71, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %74, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - " %72," - " %73," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x32 TN S32+=S8*U8 -struct 
MMA_64x144x32_S32S8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[72]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %74, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, 
%43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - " %72," - " %73," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x32 TN S32+=S8*U8 -struct MMA_64x160x32_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[80]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t 
& d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %82, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - " %80," - " %81," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - 
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x32 TN S32+=S8*U8 -struct MMA_64x160x32_S32S8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[80]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t 
& d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %82, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - " %80," - " %81," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), 
"+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x32 TN S32+=S8*U8 -struct MMA_64x176x32_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[88]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & 
d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %90, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - " %88," - " %89," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), 
"+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x32 TN S32+=S8*U8 -struct MMA_64x176x32_S32S8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[88]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & 
d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %90, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - " %88," - " %89," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), 
"+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x192x32 TN S32+=S8*U8 -struct MMA_64x192x32_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[96]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, 
uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %98, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - " %96," - " %97," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), 
"+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), - "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), - "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x192x32 TN S32+=S8*U8 -struct MMA_64x192x32_S32S8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[96]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & 
d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %98, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - " %96," - " %97," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), 
"+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), - "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), - "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x32 TN S32+=S8*U8 -struct MMA_64x208x32_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[104]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, 
uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %106, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, 
%84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - " %104," - " %105," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x32 TN S32+=S8*U8 -struct MMA_64x208x32_S32S8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - 
using CRegisters = uint32_t[104]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - 
{ -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %106, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - " %104," - " %105," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - 
"+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x32 TN S32+=S8*U8 -struct MMA_64x224x32_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[112]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, 
uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %114, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - " %112," - " %113," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), 
"+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x32 TN S32+=S8*U8 -struct MMA_64x224x32_S32S8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[112]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, 
uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, 
desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %114, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - " %112," - " %113," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - 
"+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x32 TN S32+=S8*U8 -struct MMA_64x240x32_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[120]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, 
uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %122, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, 
%117, %118, %119}," - " %120," - " %121," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x32 TN S32+=S8*U8 -struct MMA_64x240x32_S32S8U8_SS_TN_SATURATE 
-{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[120]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, 
uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %122, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - " %120," - " %121," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), 
"+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x256x32 TN S32+=S8*U8 -struct MMA_64x256x32_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[128]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & 
d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %130, 0;\n" - 
"wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - " %128," - " %129," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), 
- "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), - "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), - "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x256x32 TN S32+=S8*U8 -struct MMA_64x256x32_S32S8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[128]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, 
uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %130, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, 
%70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - " %128," - " %129," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), - "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), - "+r"(d124), "+r"(d125), "+r"(d126), 
"+r"(d127) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x8x32 TN S32+=S8*U8 -struct MMA_64x8x32_S32S8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %9, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8 " - "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," - " %8," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x8x32 TN S32+=S8*U8 -struct MMA_64x8x32_S32S8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %9, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," - " %8," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x16x32 TN S32+=S8*U8 -struct MMA_64x16x32_S32S8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %13, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "{%8, %9, %10, %11}," - " %12," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x16x32 TN S32+=S8*U8 -struct MMA_64x16x32_S32S8U8_RS_TN_SATURATE -{ - using DRegisters = 
void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %13, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "{%8, %9, %10, %11}," - " %12," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x32x32 TN S32+=S8*U8 -struct MMA_64x32x32_S32S8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred 
p;\n" - "setp.ne.b32 p, %21, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - "{%16, %17, %18, %19}," - " %20," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x32x32 TN S32+=S8*U8 -struct MMA_64x32x32_S32S8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %21, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - "{%16, %17, %18, %19}," - " %20," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "r"(a00), 
"r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x32 TN S32+=S8*U8 -struct MMA_64x48x32_S32S8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %29, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x32 TN S32+=S8*U8 -struct MMA_64x48x32_S32S8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %29, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8U8_RS_TN_SATURATE without 
CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x64x32 TN S32+=S8*U8 -struct MMA_64x64x32_S32S8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %37, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - "{%32, %33, %34, %35}," - " %36," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - 
"r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x64x32 TN S32+=S8*U8 -struct MMA_64x64x32_S32S8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %37, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - "{%32, %33, %34, %35}," - " %36," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), 
"+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x32 TN S32+=S8*U8 -struct MMA_64x80x32_S32S8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %45, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," 
- " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x32 TN S32+=S8*U8 -struct MMA_64x80x32_S32S8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %45, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x96x32 TN S32+=S8*U8 -struct MMA_64x96x32_S32S8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t 
& d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %53, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x96x32 TN S32+=S8*U8 -struct MMA_64x96x32_S32S8U8_RS_TN_SATURATE -{ - using 
DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %53, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), 
"+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x32 TN S32+=S8*U8 -struct MMA_64x112x32_S32S8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %61, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x32 TN S32+=S8*U8 -struct MMA_64x112x32_S32S8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t 
& d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %61, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - 
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x128x32 TN S32+=S8*U8 -struct MMA_64x128x32_S32S8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - GMMA::ScaleOut 
const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %69, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - "{%64, %65, %66, %67}," - " %68," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x128x32 TN S32+=S8*U8 -struct MMA_64x128x32_S32S8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; - - CUTE_HOST_DEVICE static void - 
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %69, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - "{%64, %65, %66, %67}," - " %68," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), 
"+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x32 TN S32+=S8*U8 -struct MMA_64x144x32_S32S8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[72]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, 
uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %77, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - "{%72, %73, %74, %75}," - " %76," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), 
"+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x32 TN S32+=S8*U8 -struct MMA_64x144x32_S32S8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[72]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & 
d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %77, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - "{%72, %73, %74, %75}," - " %76," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x32 TN S32+=S8*U8 -struct MMA_64x160x32_S32S8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[80]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm 
volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %85, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - "{%80, %81, %82, %83}," - " %84," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x32 TN S32+=S8*U8 -struct 
MMA_64x160x32_S32S8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[80]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %85, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, 
%11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - "{%80, %81, %82, %83}," - " %84," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x32 TN S32+=S8*U8 -struct MMA_64x176x32_S32S8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[88]; - - CUTE_HOST_DEVICE static 
void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %93, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " 
%24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - "{%88, %89, %90, %91}," - " %92," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x32 TN S32+=S8*U8 -struct MMA_64x176x32_S32S8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - 
using CRegisters = uint32_t[88]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %93, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, 
%12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - "{%88, %89, %90, %91}," - " %92," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x192x32 TN S32+=S8*U8 -struct MMA_64x192x32_S32S8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters 
= uint64_t[1]; - using CRegisters = uint32_t[96]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - 
"setp.ne.b32 p, %101, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - "{%96, %97, %98, %99}," - " %100," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), - "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), - "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif 
- } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x192x32 TN S32+=S8*U8 -struct MMA_64x192x32_S32S8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[96]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - 
uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %101, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - "{%96, %97, %98, %99}," - " %100," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), - "+r"(d88), 
"+r"(d89), "+r"(d90), "+r"(d91), - "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x32 TN S32+=S8*U8 -struct MMA_64x208x32_S32S8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[104]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, 
uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %109, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - "{%104, %105, %106, %107}," - " %108," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - 
"+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x32 TN S32+=S8*U8 -struct MMA_64x208x32_S32S8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[104]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & 
d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %109, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " 
- " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - "{%104, %105, %106, %107}," - " %108," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x32 TN S32+=S8*U8 -struct MMA_64x224x32_S32S8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[112]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & 
d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %117, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - "{%112, %113, %114, %115}," - " %116," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), 
"+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x32 TN S32+=S8*U8 -struct MMA_64x224x32_S32S8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[112]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - 
uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %117, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " 
%16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - "{%112, %113, %114, %115}," - " %116," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), 
"+r"(d110), "+r"(d111) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x32 TN S32+=S8*U8 -struct MMA_64x240x32_S32S8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[120]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & 
d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %125, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - "{%120, %121, %122, %123}," - " %124," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), 
"+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x32 TN S32+=S8*U8 -struct MMA_64x240x32_S32S8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = 
uint64_t[1]; - using CRegisters = uint32_t[120]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & 
d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %125, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - "{%120, %121, %122, %123}," - " %124," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), 
"+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x256x32 TN S32+=S8*U8 -struct MMA_64x256x32_S32S8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[128]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - 
uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - 
"{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %133, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - "{%128, %129, %130, %131}," - " %132," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), 
"+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), - "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), - "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x256x32 TN S32+=S8*U8 -struct MMA_64x256x32_S32S8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[128]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & 
d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %133, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " 
- " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - "{%128, %129, %130, %131}," - " %132," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - 
"+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), - "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), - "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x8x32 TN S32+=U8*S8 -struct MMA_64x8x32_S32U8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %6, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8 " - "{%0, %1, %2, %3}," - " %4," - " %5," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x8x32 TN S32+=U8*S8 -struct MMA_64x8x32_S32U8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { 
-#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %6, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3}," - " %4," - " %5," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x16x32 TN S32+=U8*S8 -struct MMA_64x16x32_S32U8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %10, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - " %8," - " %9," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x16x32 TN S32+=U8*S8 -struct MMA_64x16x32_S32U8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = 
uint64_t[1]; - using CRegisters = uint32_t[8]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %10, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - " %8," - " %9," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x32x32 TN S32+=U8*S8 -struct MMA_64x32x32_S32U8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %18, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - " %16," - " %17," - " 
p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x32x32 TN S32+=U8*S8 -struct MMA_64x32x32_S32U8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %18, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - " %16," - " %17," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x32 TN S32+=U8*S8 -struct MMA_64x48x32_S32U8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %26, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x32 TN S32+=U8*S8 -struct MMA_64x48x32_S32U8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = 
uint64_t[1]; - using CRegisters = uint32_t[24]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %26, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x64x32 TN S32+=U8*S8 -struct MMA_64x64x32_S32U8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, 
uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %34, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - " %32," - " %33," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x64x32 TN S32+=U8*S8 -struct MMA_64x64x32_S32U8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t 
& d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %34, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - " %32," - " %33," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x32 TN S32+=U8*S8 -struct MMA_64x80x32_S32U8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, 
uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %42, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x32 TN S32+=U8*S8 -struct MMA_64x80x32_S32U8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %42, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - 
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x96x32 TN S32+=U8*S8 -struct MMA_64x96x32_S32U8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %50, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, 
%47}," - " %48," - " %49," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x96x32 TN S32+=U8*S8 -struct MMA_64x96x32_S32U8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %50, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x32 TN S32+=U8*S8 -struct MMA_64x112x32_S32U8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, 
uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %58, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - 
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x32 TN S32+=U8*S8 -struct MMA_64x112x32_S32U8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %58, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, 
%12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x128x32 TN S32+=U8*S8 -struct MMA_64x128x32_S32U8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t 
& d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %66, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - " %64," - " %65," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - 
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x128x32 TN S32+=U8*S8 -struct MMA_64x128x32_S32U8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %66, 0;\n" - 
"wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - " %64," - " %65," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x32 TN S32+=U8*S8 -struct MMA_64x144x32_S32U8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[72]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, 
uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %74, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - " %72," - " %73," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), 
"+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x32 TN S32+=U8*S8 -struct MMA_64x144x32_S32U8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[72]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t 
& d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %74, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - " %72," - " %73," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), 
"+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x32 TN S32+=U8*S8 -struct MMA_64x160x32_S32U8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[80]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & 
d77, uint32_t & d78, uint32_t & d79, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %82, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - " %80," - " %81," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x32 TN S32+=U8*S8 -struct MMA_64x160x32_S32U8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[80]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - 
"setp.ne.b32 p, %82, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - " %80," - " %81," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x32 TN S32+=U8*S8 -struct MMA_64x176x32_S32U8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = 
uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[88]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %90, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, 
%15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - " %88," - " %89," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x32 TN S32+=U8*S8 -struct MMA_64x176x32_S32U8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = 
uint64_t[1]; - using CRegisters = uint32_t[88]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %90, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, 
%18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - " %88," - " %89," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x192x32 TN S32+=U8*S8 -struct MMA_64x192x32_S32U8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[96]; - - CUTE_HOST_DEVICE static 
void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %98, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, 
%10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - " %96," - " %97," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), - "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), - "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x192x32 TN S32+=U8*S8 -struct 
MMA_64x192x32_S32U8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[96]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, 
desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %98, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - " %96," - " %97," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), - "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), - "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8S8_SS_TN_SATURATE without 
CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x32 TN S32+=U8*S8 -struct MMA_64x208x32_S32U8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[104]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & 
d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %106, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - " %104," - " %105," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), 
"+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x32 TN S32+=U8*S8 -struct MMA_64x208x32_S32U8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[104]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - 
uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %106, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - " %104," - " %105," - " p;\n" - "}\n" - : 
"+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x32 TN S32+=U8*S8 -struct MMA_64x224x32_S32U8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[112]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, 
uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %114, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - " %112," - " %113," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), 
"+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x32 TN S32+=U8*S8 -struct MMA_64x224x32_S32U8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[112]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, 
uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %114, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - " %112," - " %113," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), 
"+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x32 TN S32+=U8*S8 -struct MMA_64x240x32_S32U8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[120]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - 
uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & 
d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %122, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - " %120," - " %121," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), 
"+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x32 TN S32+=U8*S8 -struct MMA_64x240x32_S32U8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[120]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & 
d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %122, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, 
%44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - " %120," - " %121," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) - 
: "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x256x32 TN S32+=U8*S8 -struct MMA_64x256x32_S32U8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[128]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t 
& d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %130, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - " %128," - " %129," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - 
"+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), - "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), - "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x256x32 TN S32+=U8*S8 -struct MMA_64x256x32_S32U8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using 
BRegisters = uint64_t[1]; - using CRegisters = uint32_t[128]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t 
& d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %130, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - " %128," - " %129," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), 
- "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), - "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), - "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x8x32 TN S32+=U8*S8 -struct MMA_64x8x32_S32U8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - 
"{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %9, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8 " - "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," - " %8," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x8x32 TN S32+=U8*S8 -struct MMA_64x8x32_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %9, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," - " %8," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x16x32 TN S32+=U8*S8 -struct MMA_64x16x32_S32U8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, 
uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %13, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "{%8, %9, %10, %11}," - " %12," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x16x32 TN S32+=U8*S8 -struct MMA_64x16x32_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %13, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "{%8, %9, %10, %11}," - " %12," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else 
- CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x32x32 TN S32+=U8*S8 -struct MMA_64x32x32_S32U8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %21, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - "{%16, %17, %18, %19}," - " %20," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x32x32 TN S32+=U8*S8 -struct MMA_64x32x32_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = 
uint32_t[16]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %21, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - "{%16, %17, %18, %19}," - " %20," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x32 TN S32+=U8*S8 -struct MMA_64x48x32_S32U8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & 
d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %29, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x32 TN S32+=U8*S8 -struct MMA_64x48x32_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, 
uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %29, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x64x32 TN S32+=U8*S8 -struct MMA_64x64x32_S32U8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - 
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %37, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - "{%32, %33, %34, %35}," - " %36," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x64x32 TN S32+=U8*S8 -struct MMA_64x64x32_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & 
d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %37, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - "{%32, %33, %34, %35}," - " %36," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x32 TN S32+=U8*S8 -struct MMA_64x80x32_S32U8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & 
d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %45, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x32 TN S32+=U8*S8 -struct MMA_64x80x32_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using 
ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %45, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), 
- "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x96x32 TN S32+=U8*S8 -struct MMA_64x96x32_S32U8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %53, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, 
%50, %51}," - " %52," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x96x32 TN S32+=U8*S8 -struct MMA_64x96x32_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & 
d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %53, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x32 TN S32+=U8*S8 -struct MMA_64x112x32_S32U8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & 
d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %61, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), 
"+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x32 TN S32+=U8*S8 -struct MMA_64x112x32_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %61, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x128x32 TN S32+=U8*S8 -struct MMA_64x128x32_S32U8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t 
& d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %69, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - "{%64, %65, %66, %67}," - " %68," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - 
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x128x32 TN S32+=U8*S8 -struct MMA_64x128x32_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t 
& d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %69, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - "{%64, %65, %66, %67}," - " %68," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// 
GMMA 64x144x32 TN S32+=U8*S8 -struct MMA_64x144x32_S32U8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[72]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %77, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, 
" - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - "{%72, %73, %74, %75}," - " %76," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x32 TN S32+=U8*S8 -struct MMA_64x144x32_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[72]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - 
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %77, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - "{%72, %73, %74, %75}," - " %76," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), 
- "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x32 TN S32+=U8*S8 -struct MMA_64x160x32_S32U8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[80]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & 
d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %85, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - "{%80, %81, %82, %83}," - " %84," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), 
"+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x32 TN S32+=U8*S8 -struct MMA_64x160x32_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[80]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, 
uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %85, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - "{%80, %81, %82, %83}," - " %84," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), 
"+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x32 TN S32+=U8*S8 -struct MMA_64x176x32_S32U8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[88]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, 
uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %93, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - "{%88, %89, %90, %91}," - " %92," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), 
"+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x32 TN S32+=U8*S8 -struct MMA_64x176x32_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[88]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, 
uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %93, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - "{%88, %89, %90, %91}," - " %92," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), 
"+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x192x32 TN S32+=U8*S8 -struct MMA_64x192x32_S32U8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[96]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - 
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %101, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - "{%96, %97, %98, %99}," - " %100," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), 
"+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), - "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), - "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x192x32 TN S32+=U8*S8 -struct MMA_64x192x32_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[96]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & 
d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %101, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - "{%96, %97, %98, %99}," - " %100," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), 
"+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), - "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), - "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x32 TN S32+=U8*S8 -struct MMA_64x208x32_S32U8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[104]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & 
d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %109, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, 
%35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - "{%104, %105, %106, %107}," - " %108," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x32 TN S32+=U8*S8 -struct MMA_64x208x32_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[104]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - 
uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %109, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - "{%104, %105, %106, %107}," - " %108," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), 
"+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x32 TN S32+=U8*S8 -struct MMA_64x224x32_S32U8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[112]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - 
uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %117, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, 
%67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - "{%112, %113, %114, %115}," - " %116," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x32 TN S32+=U8*S8 -struct MMA_64x224x32_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[112]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - 
uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %117, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - "{%112, %113, %114, %115}," - " %116," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - 
"+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x32 TN S32+=U8*S8 -struct MMA_64x240x32_S32U8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[120]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & 
d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %125, 0;\n" - 
"wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - "{%120, %121, %122, %123}," - " %124," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), 
"+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x32 TN S32+=U8*S8 -struct MMA_64x240x32_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[120]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - 
uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %125, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, 
%89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - "{%120, %121, %122, %123}," - " %124," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x256x32 TN S32+=U8*S8 -struct MMA_64x256x32_S32U8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[128]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, 
- uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %133, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - "{%128, %129, %130, %131}," - " %132," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), 
"+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), - "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), - "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x256x32 TN S32+=U8*S8 -struct MMA_64x256x32_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[128]; - - CUTE_HOST_DEVICE static void - 
fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t 
& d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %133, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - "{%128, %129, %130, %131}," - " %132," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), 
"+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), - "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), - "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x8x32 TN S32+=U8*U8 -struct MMA_64x8x32_S32U8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - 
"setp.ne.b32 p, %6, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 " - "{%0, %1, %2, %3}," - " %4," - " %5," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x8x32 TN S32+=U8*U8 -struct MMA_64x8x32_S32U8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %6, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3}," - " %4," - " %5," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x16x32 TN S32+=U8*U8 -struct MMA_64x16x32_S32U8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %10, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - " %8," - " %9," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x16x32 TN S32+=U8*U8 -struct MMA_64x16x32_S32U8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %10, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - " %8," - " %9," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x32x32 TN S32+=U8*U8 -struct MMA_64x32x32_S32U8U8_SS_TN -{ - 
using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %18, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - " %16," - " %17," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x32x32 TN S32+=U8*U8 -struct MMA_64x32x32_S32U8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %18, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - " %16," - " %17," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x32 TN S32+=U8*U8 -struct MMA_64x48x32_S32U8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %26, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, 
%14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x32 TN S32+=U8*U8 -struct MMA_64x48x32_S32U8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %26, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - 
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x64x32 TN S32+=U8*U8 -struct MMA_64x64x32_S32U8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %34, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - " %32," - " %33," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), 
"+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x64x32 TN S32+=U8*U8 -struct MMA_64x64x32_S32U8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %34, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - " %32," - " %33," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - 
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x32 TN S32+=U8*U8 -struct MMA_64x80x32_S32U8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %42, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, 
%30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x32 TN S32+=U8*U8 -struct MMA_64x80x32_S32U8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %42, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x96x32 TN S32+=U8*U8 -struct MMA_64x96x32_S32U8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - 
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %50, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x96x32 TN S32+=U8*U8 -struct MMA_64x96x32_S32U8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; - - 
CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %50, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), 
- "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x32 TN S32+=U8*U8 -struct MMA_64x112x32_S32U8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %58, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, 
%15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x32 TN S32+=U8*U8 -struct MMA_64x112x32_S32U8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, 
uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %58, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8U8_SS_TN_SATURATE without 
CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x128x32 TN S32+=U8*U8 -struct MMA_64x128x32_S32U8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %66, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, 
%34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - " %64," - " %65," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x128x32 TN S32+=U8*U8 -struct MMA_64x128x32_S32U8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, 
uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %66, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - " %64," - " %65," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), 
"+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x32 TN S32+=U8*U8 -struct MMA_64x144x32_S32U8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[72]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %74, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - " %72," - " %73," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x32 TN S32+=U8*U8 -struct MMA_64x144x32_S32U8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = 
uint64_t[1]; - using CRegisters = uint32_t[72]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %74, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " 
%64, %65, %66, %67, %68, %69, %70, %71}," - " %72," - " %73," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x32 TN S32+=U8*U8 -struct MMA_64x160x32_S32U8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[80]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, 
uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %82, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - " %80," - " %81," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), 
"+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x32 TN S32+=U8*U8 -struct MMA_64x160x32_S32U8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[80]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - 
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %82, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - " %80," - " %81," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - 
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x32 TN S32+=U8*U8 -struct MMA_64x176x32_S32U8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[88]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & 
d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %90, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - " %88," - " %89," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), 
"+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x32 TN S32+=U8*U8 -struct MMA_64x176x32_S32U8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[88]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & 
d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %90, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - " %88," - " %89," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), 
"+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x192x32 TN S32+=U8*U8 -struct MMA_64x192x32_S32U8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[96]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - 
uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %98, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - " %96," - " %97," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), 
"+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), - "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), - "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x192x32 TN S32+=U8*U8 -struct MMA_64x192x32_S32U8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[96]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & 
d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %98, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - " %96," - " %97," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), 
"+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), - "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), - "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x32 TN S32+=U8*U8 -struct MMA_64x208x32_S32U8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[104]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & 
d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %106, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, 
%103}," - " %104," - " %105," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x32 TN S32+=U8*U8 -struct MMA_64x208x32_S32U8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[104]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t 
const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, 
desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %106, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - " %104," - " %105," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), 
"+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x32 TN S32+=U8*U8 -struct MMA_64x224x32_S32U8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[112]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t 
& d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %114, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - " %112," - " %113," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - 
"+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x32 TN S32+=U8*U8 -struct MMA_64x224x32_S32U8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[112]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & 
d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %114, 0;\n" - 
"wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - " %112," - " %113," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), 
"+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x32 TN S32+=U8*U8 -struct MMA_64x240x32_S32U8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[120]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, 
uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %122, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - " %120," - " %121," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), 
"+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x32 TN S32+=U8*U8 -struct MMA_64x240x32_S32U8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = 
uint64_t[1]; - using CRegisters = uint32_t[120]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, 
uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %122, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - " %120," - " %121," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - 
"+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x256x32 TN S32+=U8*U8 -struct MMA_64x256x32_S32U8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[128]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & 
d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %130, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, 
%9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - " %128," - " %129," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), 
- "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), - "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), - "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x256x32 TN S32+=U8*U8 -struct MMA_64x256x32_S32U8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[128]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & 
d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %130, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, 
%87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - " %128," - " %129," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), - "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), - "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x8x32 TN S32+=U8*U8 -struct MMA_64x8x32_S32U8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %9, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 " - "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," - " %8," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x8x32 TN S32+=U8*U8 -struct MMA_64x8x32_S32U8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - 
"setp.ne.b32 p, %9, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," - " %8," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x16x32 TN S32+=U8*U8 -struct MMA_64x16x32_S32U8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %13, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "{%8, %9, %10, %11}," - " %12," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x16x32 TN S32+=U8*U8 -struct MMA_64x16x32_S32U8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; 
- - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %13, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "{%8, %9, %10, %11}," - " %12," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x32x32 TN S32+=U8*U8 -struct MMA_64x32x32_S32U8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %21, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8 " - "{%0, %1, %2, 
%3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - "{%16, %17, %18, %19}," - " %20," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x32x32 TN S32+=U8*U8 -struct MMA_64x32x32_S32U8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %21, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - "{%16, %17, %18, %19}," - " %20," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x32 TN S32+=U8*U8 -struct MMA_64x48x32_S32U8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %29, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); 
-#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x32 TN S32+=U8*U8 -struct MMA_64x48x32_S32U8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %29, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x64x32 TN S32+=U8*U8 -struct MMA_64x64x32_S32U8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %37, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - "{%32, %33, %34, %35}," - " %36," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x64x32 TN S32+=U8*U8 -struct MMA_64x64x32_S32U8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %37, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - "{%32, %33, %34, %35}," - " %36," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), 
"+r"(d30), "+r"(d31) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x32 TN S32+=U8*U8 -struct MMA_64x80x32_S32U8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %45, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " p;\n" - "}\n" - : "+r"(d00), 
"+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x32 TN S32+=U8*U8 -struct MMA_64x80x32_S32U8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %45, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x96x32 TN S32+=U8*U8 -struct MMA_64x96x32_S32U8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t 
& d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %53, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x96x32 TN S32+=U8*U8 -struct MMA_64x96x32_S32U8U8_RS_TN_SATURATE -{ - using 
DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %53, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), 
"+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x32 TN S32+=U8*U8 -struct MMA_64x112x32_S32U8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %61, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x32 TN S32+=U8*U8 -struct MMA_64x112x32_S32U8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t 
& d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %61, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - 
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x128x32 TN S32+=U8*U8 -struct MMA_64x128x32_S32U8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - GMMA::ScaleOut 
const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %69, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - "{%64, %65, %66, %67}," - " %68," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x128x32 TN S32+=U8*U8 -struct MMA_64x128x32_S32U8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; - - CUTE_HOST_DEVICE static void - 
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %69, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - "{%64, %65, %66, %67}," - " %68," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), 
"+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x32 TN S32+=U8*U8 -struct MMA_64x144x32_S32U8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[72]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, 
uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %77, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - "{%72, %73, %74, %75}," - " %76," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), 
"+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x32 TN S32+=U8*U8 -struct MMA_64x144x32_S32U8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[72]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & 
d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %77, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - "{%72, %73, %74, %75}," - " %76," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x32 TN S32+=U8*U8 -struct MMA_64x160x32_S32U8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[80]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm 
volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %85, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - "{%80, %81, %82, %83}," - " %84," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x32 TN S32+=U8*U8 -struct 
MMA_64x160x32_S32U8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[80]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %85, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, 
%11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - "{%80, %81, %82, %83}," - " %84," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x32 TN S32+=U8*U8 -struct MMA_64x176x32_S32U8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[88]; - - CUTE_HOST_DEVICE static 
void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %93, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " 
%24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - "{%88, %89, %90, %91}," - " %92," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x32 TN S32+=U8*U8 -struct MMA_64x176x32_S32U8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - 
using CRegisters = uint32_t[88]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %93, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, 
%12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - "{%88, %89, %90, %91}," - " %92," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x192x32 TN S32+=U8*U8 -struct MMA_64x192x32_S32U8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters 
= uint64_t[1]; - using CRegisters = uint32_t[96]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - 
"setp.ne.b32 p, %101, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - "{%96, %97, %98, %99}," - " %100," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), - "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), - "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif 
- } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x192x32 TN S32+=U8*U8 -struct MMA_64x192x32_S32U8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[96]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - 
uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %101, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - "{%96, %97, %98, %99}," - " %100," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), - "+r"(d88), 
"+r"(d89), "+r"(d90), "+r"(d91), - "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x32 TN S32+=U8*U8 -struct MMA_64x208x32_S32U8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[104]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, 
uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %109, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - "{%104, %105, %106, %107}," - " %108," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - 
"+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x32 TN S32+=U8*U8 -struct MMA_64x208x32_S32U8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[104]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & 
d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %109, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " 
- " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - "{%104, %105, %106, %107}," - " %108," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x32 TN S32+=U8*U8 -struct MMA_64x224x32_S32U8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[112]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & 
d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %117, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - "{%112, %113, %114, %115}," - " %116," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), 
"+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x32 TN S32+=U8*U8 -struct MMA_64x224x32_S32U8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[112]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - 
uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %117, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " 
%16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - "{%112, %113, %114, %115}," - " %116," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), 
"+r"(d110), "+r"(d111) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x32 TN S32+=U8*U8 -struct MMA_64x240x32_S32U8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[120]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & 
d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %125, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - "{%120, %121, %122, %123}," - " %124," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), 
"+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x32 TN S32+=U8*U8 -struct MMA_64x240x32_S32U8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = 
uint64_t[1]; - using CRegisters = uint32_t[120]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & 
d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %125, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - "{%120, %121, %122, %123}," - " %124," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), 
"+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x256x32 TN S32+=U8*U8 -struct MMA_64x256x32_S32U8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[128]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - 
uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - 
"{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %133, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - "{%128, %129, %130, %131}," - " %132," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), 
"+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), - "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), - "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x256x32 TN S32+=U8*U8 -struct MMA_64x256x32_S32U8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[128]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & 
d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %133, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " 
- " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - "{%128, %129, %130, %131}," - " %132," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - 
"+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), - "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), - "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x8x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x8x32_F16E4M3E4M3_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[2]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %4, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e4m3 " - "{%0, %1}," - " %2," - " %3," - " p, %5, %6;\n" - "}\n" - : "+r"(d0), "+r"(d1) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x8x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x8x32_F16E4M3E4M3_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = 
uint32_t[2]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %7, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e4m3 " - "{%0, %1}," - "{%2, %3, %4, %5}," - " %6," - " p, %8, %9;\n" - "}\n" - : "+r"(d0), "+r"(d1) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x8x32 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x8x32_F32E4M3E4M3_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[4]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %6, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3 " - "{%0, %1, %2, %3}," - " %4," - " %5," - " p, %7, %8;\n" - "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E4M3E4M3_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x8x32 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x8x32_F32E4M3E4M3_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[4]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %9, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3 " - "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," - " %8," - " p, %10, %11;\n" - "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x16x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x16x32_F16E4M3E4M3_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %6, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.f16.e4m3.e4m3 " - "{%0, %1, %2, %3}," - " %4," - " %5," - " p, %7, %8;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x16x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x16x32_F16E4M3E4M3_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %9, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.f16.e4m3.e4m3 " - "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," - " %8," - " p, %10, %11;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x16x32 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x16x32_F32E4M3E4M3_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[8]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, - float & d4, float & d5, float & d6, float & d7, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %10, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - " %8," - " %9," - " p, %11, %12;\n" - "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), - "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x16x32 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x16x32_F32E4M3E4M3_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[8]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, - float & d4, float & d5, float & d6, float & d7, - GMMA::ScaleOut const 
scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %13, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "{%8, %9, %10, %11}," - " %12," - " p, %14, %15;\n" - "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), - "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x32x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x32x32_F16E4M3E4M3_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %10, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - " %8," - " %9," - " p, %11, %12;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x32x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x32x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x32x32_F16E4M3E4M3_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %13, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "{%8, %9, %10, %11}," - " %12," - " p, %14, %15;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x32x32 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x32x32_F32E4M3E4M3_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[16]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - 
uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %18, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - " %16," - " %17," - " p, %19, %20;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x32x32 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x32x32_F32E4M3E4M3_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[16]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %21, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - "{%16, %17, %18, %19}," - " %20," - " p, %22, %23;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x48x32_F16E4M3E4M3_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[12]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %14, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11}," - " %12," - " %13," - " p, %15, %16;\n" - "}\n" - : "+r"(d00), "+r"(d01), 
"+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x48x32_F16E4M3E4M3_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[12]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %17, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11}," - "{%12, %13, %14, %15}," - " %16," - " p, %18, %19;\n" + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) - : 
"r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x32 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x48x32_F32E4M3E4M3_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[24]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %26, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " p, %27, %28;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), 
"+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x32 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x48x32_F32E4M3E4M3_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[24]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %29, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, 
%25, %26, %27}," - " %28," - " p, %30, %31;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x64x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x64x32_F16E4M3E4M3_SS_TN +// GMMA 64x128x32 TN S32+=S8*S8 +struct MMA_64x128x32_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -32327,108 +5034,93 @@ struct MMA_64x64x32_F16E4M3E4M3_SS_TN uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & 
d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %18, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - " %16," - " %17," - " p, %19, %20;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x64x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x64x32_F16E4M3E4M3_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t 
& d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %21, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - "{%16, %17, %18, %19}," - " %20," - " p, %22, %23;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x64x32 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x64x32_F32E4M3E4M3_SS_TN +// GMMA 64x128x32 TN S32+=S8*S8 +struct MMA_64x128x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[32]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + 
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -32436,106 +5128,149 @@ struct MMA_64x64x32_F32E4M3E4M3_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %34, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - " %32," - " %33," - " p, %35, %36;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), 
"+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x64x32 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x64x32_F32E4M3E4M3_RS_TN +// GMMA 64x192x32 TN S32+=S8*S8 +struct MMA_64x192x32_S32S8S8_SS_TN { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[32]; + using CRegisters = uint32_t[96]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, + uint32_t & d00, uint32_t & d01, uint32_t & 
d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %37, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, 
%19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - "{%32, %33, %34, %35}," - " %36," - " p, %38, %39;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + 
"+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x80x32_F16E4M3E4M3_SS_TN +// GMMA 64x192x32 TN S32+=S8*S8 +struct MMA_64x192x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[20]; + using CRegisters = uint32_t[96]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -32545,6 +5280,25 @@ struct MMA_64x80x32_F16E4M3E4M3_SS_TN uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t 
& d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -32552,112 +5306,217 @@ struct MMA_64x80x32_F16E4M3E4M3_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %22, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19}," - " %20," - " %21," - " p, %23, %24;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), 
"+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x80x32_F16E4M3E4M3_RS_TN +// GMMA 64x256x32 TN S32+=S8*S8 +struct MMA_64x256x32_S32S8S8_SS_TN { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using BRegisters = 
uint64_t[1]; - using CRegisters = uint32_t[20]; + using CRegisters = uint32_t[128]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & 
d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %25, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19}," - "{%20, %21, %22, %23}," - " %24," - " p, %26, %27;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, 
%100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), 
"+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x32 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x80x32_F32E4M3E4M3_SS_TN +// GMMA 64x256x32 TN S32+=S8*S8 +struct MMA_64x256x32_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[40]; + using CRegisters = uint32_t[128]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, 
uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -32665,65 +5524,83 @@ struct MMA_64x80x32_F32E4M3E4M3_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %42, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " p, %43, %44;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), 
"+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x32 TN 
F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x80x32_F32E4M3E4M3_RS_TN +// GMMA 64x8x32 TN S32+=S8*S8 +struct MMA_64x8x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[40]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -32731,114 +5608,76 @@ struct MMA_64x80x32_F32E4M3E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %45, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " p, %46, %47;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), 
"+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x96x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x96x32_F16E4M3E4M3_SS_TN +// GMMA 64x8x32 TN S32+=S8*S8 +struct MMA_64x8x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + 
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %26, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " p, %27, %28;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "l"(desc_a), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x96x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x96x32_F16E4M3E4M3_RS_TN +// GMMA 64x16x32 TN S32+=S8*S8 +struct MMA_64x16x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; + using CRegisters = uint32_t[8]; 
CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -32846,128 +5685,81 @@ struct MMA_64x96x32_F16E4M3E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %29, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " p, %30, %31;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x96x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x96x32 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x96x32_F32E4M3E4M3_SS_TN +// GMMA 64x16x32 TN S32+=S8*S8 +struct MMA_64x16x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[48]; + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %50, 0;\n" - 
"wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " p, %51, %52;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) - : "l"(desc_a), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x96x32 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x96x32_F32E4M3E4M3_RS_TN +// GMMA 64x32x32 TN S32+=S8*S8 +struct MMA_64x32x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = 
uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[48]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -32975,112 +5767,81 @@ struct MMA_64x96x32_F32E4M3E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %53, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " p, %54, %55;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" "}\n" - : 
"+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x112x32_F16E4M3E4M3_SS_TN +// GMMA 64x32x32 TN S32+=S8*S8 +struct MMA_64x32x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[28]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, 
uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %30, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k32.f16.e4m3.e4m3 " + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27}," - " %28," - " %29," - " p, %31, %32;\n" + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) - : "l"(desc_a), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif 
//////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x112x32_F16E4M3E4M3_RS_TN +// GMMA 64x64x32 TN S32+=S8*S8 +struct MMA_64x64x32_S32S8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[28]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -33092,6 +5853,7 @@ struct MMA_64x112x32_F16E4M3E4M3_RS_TN uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -33099,15 +5861,15 @@ struct MMA_64x112x32_F16E4M3E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %33, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k32.f16.e4m3.e4m3 " + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27}," - "{%28, %29, %30, %31}," - " %32," - " p, %34, %35;\n" + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -33115,125 +5877,38 @@ struct MMA_64x112x32_F16E4M3E4M3_RS_TN "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), 
"+r"(d27) + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x32 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x112x32_F32E4M3E4M3_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[56]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %58, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e4m3 " - 
"{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " p, %59, %60;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x32 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x112x32_F32E4M3E4M3_RS_TN +// GMMA 64x64x32 TN S32+=S8*S8 +struct MMA_64x64x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[56]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void fma(uint32_t 
const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -33241,59 +5916,45 @@ struct MMA_64x112x32_F32E4M3E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %61, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, 
%59}," - " %60," - " p, %62, %63;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x128x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x128x32_F16E4M3E4M3_SS_TN +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=S8*S8 +struct MMA_64x96x32_S32S8S8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, @@ -33303,22 +5964,28 @@ struct MMA_64x128x32_F16E4M3E4M3_SS_TN uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %34, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - " %32," - " %33," - " p, %35, %36;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, 
%14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -33327,29 +5994,29 @@ struct MMA_64x128x32_F16E4M3E4M3_SS_TN "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) - : "l"(desc_a), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x128x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x128x32_F16E4M3E4M3_RS_TN +// GMMA 64x96x32 TN S32+=S8*S8 +struct MMA_64x96x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -33362,6 +6029,10 @@ struct MMA_64x128x32_F16E4M3E4M3_RS_TN uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t 
& d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -33369,15 +6040,17 @@ struct MMA_64x128x32_F16E4M3E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %37, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - "{%32, %33, %34, %35}," - " %36," - " p, %38, %39;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -33386,58 +6059,58 @@ struct MMA_64x128x32_F16E4M3E4M3_RS_TN "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + 
"r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x128x32 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x128x32_F32E4M3E4M3_SS_TN +// GMMA 64x128x32 TN S32+=S8*S8 +struct MMA_64x128x32_S32S8S8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[64]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + 
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %66, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3 " + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -33446,68 +6119,64 @@ struct MMA_64x128x32_F32E4M3E4M3_SS_TN " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " " %56, %57, %58, %59, %60, %61, %62, %63}," - " %64," - " %65," - " p, %67, %68;\n" + "{%64, %65, %66, %67}," + " %68," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), 
"+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) - : "l"(desc_a), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x128x32 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x128x32_F32E4M3E4M3_RS_TN +// GMMA 64x128x32 TN S32+=S8*S8 +struct 
MMA_64x128x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[64]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + 
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -33516,7 +6185,7 @@ struct MMA_64x128x32_F32E4M3E4M3_RS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %69, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3 " + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -33527,50 +6196,45 @@ struct MMA_64x128x32_F32E4M3E4M3_RS_TN " %56, %57, %58, %59, %60, %61, %62, %63}," "{%64, %65, %66, %67}," " %68," - " p, %70, %71;\n" + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), 
+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x144x32_F16E4M3E4M3_SS_TN +// GMMA 64x192x32 TN S32+=S8*S8 +struct MMA_64x192x32_S32S8S8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[36]; + using CRegisters = uint32_t[96]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, @@ -33581,23 +6245,45 @@ struct MMA_64x144x32_F16E4M3E4M3_SS_TN uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, 
+ uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %38, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k32.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35}," - " %36," - " %37," - " p, %39, %40;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, 
%70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -33607,31 +6293,40 @@ struct MMA_64x144x32_F16E4M3E4M3_SS_TN "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) - : "l"(desc_a), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct 
MMA_64x144x32_F16E4M3E4M3_RS_TN +// GMMA 64x192x32 TN S32+=S8*S8 +struct MMA_64x192x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[36]; + using CRegisters = uint32_t[96]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -33645,6 +6340,21 @@ struct MMA_64x144x32_F16E4M3E4M3_RS_TN uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -33652,16 +6362,23 @@ struct MMA_64x144x32_F16E4M3E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %41, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k32.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " 
- " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35}," - "{%36, %37, %38, %39}," - " %40," - " p, %42, %43;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -33671,62 +6388,85 @@ struct MMA_64x144x32_F16E4M3E4M3_RS_TN "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x32 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x144x32_F32E4M3E4M3_SS_TN +// GMMA 64x256x32 TN S32+=S8*S8 +struct MMA_64x256x32_S32S8S8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[72]; + using CRegisters = uint32_t[128]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, 
- float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + 
uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %74, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k32.f32.e4m3.e4m3 " + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -33735,75 +6475,104 @@ struct MMA_64x144x32_F32E4M3E4M3_SS_TN " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - " %72," - " %73," - " p, %75, %76;\n" + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), 
"+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) - : "l"(desc_a), + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), 
+ "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x32 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x144x32_F32E4M3E4M3_RS_TN +// GMMA 64x256x32 TN S32+=S8*S8 +struct MMA_64x256x32_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[72]; + using CRegisters = uint32_t[128]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, 
float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, 
uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -33811,8 +6580,8 @@ struct MMA_64x144x32_F32E4M3E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %77, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k32.f32.e4m3.e4m3 " + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -33821,67 +6590,73 @@ struct MMA_64x144x32_F32E4M3E4M3_RS_TN " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - "{%72, %73, %74, %75}," - " %76," - " p, %78, %79;\n" + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), 
"+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), 
"+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x160x32_F16E4M3E4M3_SS_TN +// GMMA 64x8x32 TN S32+=S8*U8 +struct MMA_64x8x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, 
uint32_t & d39, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -33889,141 +6664,76 @@ struct MMA_64x160x32_F16E4M3E4M3_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %42, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k32.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " p, %43, %44;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x160x32_F16E4M3E4M3_RS_TN +// GMMA 64x8x32 TN 
S32+=S8*U8 +struct MMA_64x8x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %45, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k32.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " p, %46, %47;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), 
"+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x32 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x160x32_F32E4M3E4M3_SS_TN +// GMMA 64x16x32 TN S32+=S8*U8 +struct MMA_64x16x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[80]; + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float 
& d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -34031,157 +6741,73 @@ struct MMA_64x160x32_F32E4M3E4M3_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %82, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - " %80," - " %81," - " p, %83, %84;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), 
"+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x32 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x160x32_F32E4M3E4M3_RS_TN +// GMMA 64x16x32 TN S32+=S8*U8 +struct MMA_64x16x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[80]; + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float 
& d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %85, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - "{%80, %81, %82, %83}," - " %84," - " p, %86, %87;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, 
%7}," + " %8," + " %9," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x176x32_F16E4M3E4M3_SS_TN +// GMMA 64x32x32 TN S32+=S8*U8 +struct MMA_64x32x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = 
uint32_t[44]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -34190,13 +6816,6 @@ struct MMA_64x176x32_F16E4M3E4M3_SS_TN uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -34204,148 +6823,93 @@ struct MMA_64x176x32_F16E4M3E4M3_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %46, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k32.f16.e4m3.e4m3 " + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43}," - " %44," - " %45," - " p, %47, %48;\n" + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), 
"+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x176x32_F16E4M3E4M3_RS_TN +// GMMA 64x32x32 TN S32+=S8*U8 +struct MMA_64x32x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[44]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %49, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k32.f16.e4m3.e4m3 " + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43}," - "{%44, %45, %46, %47}," - " %48," - " p, %50, %51;\n" + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x32 TN 
F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x176x32_F32E4M3E4M3_SS_TN +// GMMA 64x64x32 TN S32+=S8*U8 +struct MMA_64x64x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[88]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & 
d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -34353,159 +6917,92 @@ struct MMA_64x176x32_F32E4M3E4M3_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %90, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k32.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - " %88," - " %89," - " p, %91, %92;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - 
"+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x32 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x176x32_F32E4M3E4M3_RS_TN +// GMMA 64x64x32 TN S32+=S8*U8 +struct MMA_64x64x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[88]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, 
float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %93, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k32.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, 
%15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - "{%88, %89, %90, %91}," - " %92," - " p, %94, %95;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), 
"+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x192x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x192x32_F16E4M3E4M3_SS_TN +// GMMA 64x96x32 TN S32+=S8*U8 +struct MMA_64x96x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -34535,7 +7032,7 @@ struct MMA_64x192x32_F16E4M3E4M3_SS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %50, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k32.f16.e4m3.e4m3 " + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -34544,7 +7041,7 @@ struct MMA_64x192x32_F16E4M3E4M3_SS_TN " %40, %41, %42, %43, %44, %45, %46, %47}," " %48," " %49," - " p, %51, %52;\n" + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -34560,29 +7057,25 @@ struct MMA_64x192x32_F16E4M3E4M3_SS_TN "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x192x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x192x32_F16E4M3E4M3_RS_TN +// GMMA 64x96x32 TN S32+=S8*U8 +struct MMA_64x96x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, @@ -34599,21 +7092,21 @@ struct MMA_64x192x32_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %53, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k32.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " p, %54, %55;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " 
+ " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -34627,56 +7120,44 @@ struct MMA_64x192x32_F16E4M3E4M3_RS_TN "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x192x32 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x192x32_F32E4M3E4M3_SS_TN +// GMMA 64x128x32 TN S32+=S8*U8 +struct MMA_64x128x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[96]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float 
& d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, - float & d88, float & d89, float & d90, float & d91, - float & d92, float & d93, float & d94, float & d95, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -34684,8 +7165,8 @@ struct MMA_64x192x32_F32E4M3E4M3_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %98, 0;\n" - 
"wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3 " + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -34693,98 +7174,74 @@ struct MMA_64x192x32_F32E4M3E4M3_SS_TN " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - " %96," - " %97," - " p, %99, %100;\n" + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), - "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), - "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), 
"+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x192x32 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x192x32_F32E4M3E4M3_RS_TN +// GMMA 64x128x32 TN S32+=S8*U8 +struct MMA_64x128x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[96]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, 
float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, - float & d88, float & d89, float & d90, float & d91, - float & d92, float & d93, float & d94, float & d95, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, 
uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %101, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3 " + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -34792,62 +7249,45 @@ struct MMA_64x192x32_F32E4M3E4M3_RS_TN " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - "{%96, %97, %98, %99}," - " %100," - " p, %102, %103;\n" + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - 
"+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), - "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), - "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x208x32_F16E4M3E4M3_SS_TN +// GMMA 64x192x32 TN S32+=S8*U8 +struct MMA_64x192x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using 
BRegisters = uint64_t[1]; - using CRegisters = uint32_t[52]; + using CRegisters = uint32_t[96]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -34865,6 +7305,17 @@ struct MMA_64x208x32_F16E4M3E4M3_SS_TN uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -34872,18 +7323,23 @@ struct MMA_64x208x32_F16E4M3E4M3_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %54, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k32.f16.e4m3.e4m3 " + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51}," - " %52," - " %53," - " p, %55, %56;\n" + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, 
%79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -34897,34 +7353,39 @@ struct MMA_64x208x32_F16E4M3E4M3_SS_TN "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x208x32_F16E4M3E4M3_RS_TN +// GMMA 64x192x32 TN S32+=S8*U8 +struct MMA_64x192x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[52]; + using CRegisters = 
uint32_t[96]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, @@ -34939,25 +7400,41 @@ struct MMA_64x208x32_F16E4M3E4M3_RS_TN uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %57, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k32.f16.e4m3.e4m3 " + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, 
%47, " - " %48, %49, %50, %51}," - "{%52, %53, %54, %55}," - " %56," - " p, %58, %59;\n" + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -34971,61 +7448,72 @@ struct MMA_64x208x32_F16E4M3E4M3_RS_TN "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x32 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One -> -struct MMA_64x208x32_F32E4M3E4M3_SS_TN +// GMMA 64x256x32 TN S32+=S8*U8 +struct MMA_64x256x32_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[104]; + using CRegisters = uint32_t[128]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & 
d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & 
d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -35033,8 +7521,8 @@ struct MMA_64x208x32_F32E4M3E4M3_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %106, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k32.f32.e4m3.e4m3 " + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -35047,100 +7535,109 @@ struct MMA_64x208x32_F32E4M3E4M3_SS_TN " %72, %73, %74, %75, %76, %77, %78, %79, " " %80, %81, %82, %83, %84, %85, %86, %87, " " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - " %104," - " %105," - " p, %107, %108;\n" + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), 
"+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), 
"+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x32 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x208x32_F32E4M3E4M3_RS_TN +// GMMA 64x256x32 TN S32+=S8*U8 +struct MMA_64x256x32_S32S8U8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[104]; + using CRegisters = uint32_t[128]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - 
float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, 
uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %109, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k32.f32.e4m3.e4m3 " + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -35153,155 +7650,146 @@ struct MMA_64x208x32_F32E4M3E4M3_RS_TN " %72, %73, %74, %75, %76, %77, %78, %79, " " %80, %81, %82, %83, %84, %85, %86, %87, " " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - "{%104, %105, %106, %107}," - " %108," - " p, %110, %111;\n" + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, 
%113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + 
"+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x224x32_F16E4M3E4M3_SS_TN +// GMMA 64x8x32 TN S32+=S8*U8 +struct MMA_64x8x32_S32S8U8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using 
BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %58, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k32.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " p, %59, %60;\n" + "setp.ne.b32 p, %9, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "l"(desc_a), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=S8*U8 +struct MMA_64x8x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif + } +}; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x224x32_F16E4M3E4M3_RS_TN +// GMMA 64x16x32 TN S32+=S8*U8 +struct MMA_64x16x32_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, 
uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -35309,200 +7797,126 @@ struct MMA_64x224x32_F16E4M3E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %61, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k32.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " p, %62, %63;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=S8*U8 +struct MMA_64x16x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif + } +}; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x32 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x224x32_F32E4M3E4M3_SS_TN +// GMMA 64x32x32 TN S32+=S8*U8 +struct MMA_64x32x32_S32S8U8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using 
ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[112]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t 
& d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %114, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k32.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - " %112," - " %113," - " p, %115, %116;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), 
"+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) - : "l"(desc_a), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x32 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x224x32_F32E4M3E4M3_RS_TN +// GMMA 64x32x32 TN S32+=S8*U8 +struct MMA_64x32x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = 
float[112]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, 
uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -35510,81 +7924,39 @@ struct MMA_64x224x32_F32E4M3E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %117, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k32.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - "{%112, %113, %114, %115}," - " %116," - " p, %118, %119;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), 
"+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x240x32_F16E4M3E4M3_SS_TN +// GMMA 64x64x32 TN S32+=S8*U8 +struct MMA_64x64x32_S32S8U8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using 
CRegisters = uint32_t[60]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, @@ -35594,33 +7966,22 @@ struct MMA_64x240x32_F16E4M3E4M3_SS_TN uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %62, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k32.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59}," - " %60," - " %61," - " p, %63, %64;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, 
%18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -35629,38 +7990,25 @@ struct MMA_64x240x32_F16E4M3E4M3_SS_TN "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) - : "l"(desc_a), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x240x32_F16E4M3E4M3_RS_TN +// GMMA 64x64x32 TN S32+=S8*U8 +struct MMA_64x64x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[60]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -35673,13 +8021,6 @@ struct 
MMA_64x240x32_F16E4M3E4M3_RS_TN uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -35687,19 +8028,15 @@ struct MMA_64x240x32_F16E4M3E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %65, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k32.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59}," - "{%60, %61, %62, %63}," - " %64," - " p, %66, %67;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -35708,188 +8045,106 @@ struct MMA_64x240x32_F16E4M3E4M3_RS_TN "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), 
"+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x32 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x240x32_F32E4M3E4M3_SS_TN +// GMMA 64x96x32 TN S32+=S8*U8 +struct MMA_64x96x32_S32S8U8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[120]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, 
float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, 
uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %122, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k32.f32.e4m3.e4m3 " + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - " %120," - " %121," - " p, %123, %124;\n" + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - 
"+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) - : "l"(desc_a), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x32 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x240x32_F32E4M3E4M3_RS_TN +// GMMA 64x96x32 TN S32+=S8*U8 +struct MMA_64x96x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[120]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & 
d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -35897,83 +8152,51 @@ struct MMA_64x240x32_F32E4M3E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %125, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k32.f32.e4m3.e4m3 " + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, 
%71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - "{%120, %121, %122, %123}," - " %124," - " p, %126, %127;\n" + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) - : "r"(a000), "r"(a001), 
"r"(a002), "r"(a003), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x256x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x256x32_F16E4M3E4M3_SS_TN +// GMMA 64x128x32 TN S32+=S8*U8 +struct MMA_64x128x32_S32S8U8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, @@ -35994,12 +8217,12 @@ struct MMA_64x256x32_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %66, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e4m3 " + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -36008,9 +8231,9 @@ struct MMA_64x256x32_F16E4M3E4M3_SS_TN " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " " %56, %57, %58, %59, %60, %61, %62, %63}," - " %64," - " %65," - " p, %67, %68;\n" + "{%64, %65, %66, %67}," + " %68," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -36028,23 +8251,19 @@ struct MMA_64x256x32_F16E4M3E4M3_SS_TN "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "l"(desc_a), + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x256x32 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x256x32_F16E4M3E4M3_RS_TN +// GMMA 64x128x32 TN S32+=S8*U8 +struct MMA_64x128x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -36078,7 +8297,7 @@ struct MMA_64x256x32_F16E4M3E4M3_RS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %69, 0;\n" - 
"wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e4m3 " + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -36089,7 +8308,7 @@ struct MMA_64x256x32_F16E4M3E4M3_RS_TN " %56, %57, %58, %59, %60, %61, %62, %63}," "{%64, %65, %66, %67}," " %68," - " p, %70, %71;\n" + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -36109,71 +8328,59 @@ struct MMA_64x256x32_F16E4M3E4M3_RS_TN "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x256x32 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x256x32_F32E4M3E4M3_SS_TN +// GMMA 64x192x32 TN S32+=S8*U8 +struct MMA_64x192x32_S32S8U8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[128]; + using CRegisters = uint32_t[96]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & 
d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, - float & d120, float & d121, float & d122, float & d123, - float & d124, float & d125, float & d126, float & d127, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + 
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %130, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3 " + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -36185,105 +8392,81 @@ struct MMA_64x256x32_F32E4M3E4M3_SS_TN " %64, %65, %66, %67, %68, %69, %70, %71, " " %72, %73, %74, %75, %76, %77, %78, %79, " " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " 
%96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - " %128," - " %129," - " p, %131, %132;\n" + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), - "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), - "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) - : "l"(desc_a), + : "+r"(d00), "+r"(d01), 
"+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x256x32 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x256x32_F32E4M3E4M3_RS_TN +// GMMA 64x192x32 TN S32+=S8*U8 +struct MMA_64x192x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[128]; + using CRegisters = 
uint32_t[96]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, - float 
& d120, float & d121, float & d122, float & d123, - float & d124, float & d125, float & d126, float & d127, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -36291,8 +8474,8 @@ struct MMA_64x256x32_F32E4M3E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %133, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3 " + "setp.ne.b32 p, %101, 
0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -36304,116 +8487,204 @@ struct MMA_64x256x32_F32E4M3E4M3_RS_TN " %64, %65, %66, %67, %68, %69, %70, %71, " " %72, %73, %74, %75, %76, %77, %78, %79, " " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - "{%128, %129, %130, %131}," - " %132," - " p, %134, %135;\n" + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), 
"+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), - "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), - "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; 
//////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x8x32 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x8x32_F16E4M3E5M2_SS_TN +// GMMA 64x256x32 TN S32+=S8*U8 +struct MMA_64x256x32_S32S8U8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[2]; + using CRegisters = uint32_t[128]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, 
uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %4, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e5m2 " - "{%0, %1}," - " %2," - " %3," - " p, %5, %6;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, 
%90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" "}\n" - : "+r"(d0), "+r"(d1) - : "l"(desc_a), + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), 
"l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x8x32 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x8x32_F16E4M3E5M2_RS_TN +// GMMA 64x256x32 TN S32+=S8*U8 +struct MMA_64x256x32_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[2]; + using CRegisters = uint32_t[128]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, 
uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -36421,41 +8692,83 @@ struct MMA_64x8x32_F16E4M3E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %7, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e5m2 " - "{%0, %1}," - "{%2, %3, %4, %5}," - " %6," - " p, %8, %9;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " 
%32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" "}\n" - : "+r"(d0), "+r"(d1) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + 
"+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x8x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x8x32_F32E4M3E5M2_SS_TN +// GMMA 64x8x32 TN S32+=U8*S8 +struct MMA_64x8x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[4]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -36464,82 +8777,75 @@ struct MMA_64x8x32_F32E4M3E5M2_SS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %6, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e5m2 " + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8 " "{%0, %1, %2, %3}," " %4," " %5," - " p, %7, %8;\n" + " p;\n" "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); 
#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x8x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x8x32_F32E4M3E5M2_RS_TN +// GMMA 64x8x32 TN S32+=U8*S8 +struct MMA_64x8x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[4]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %9, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e5m2 " + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8.satfinite " "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," - " %8," - " p, %10, %11;\n" + " %4," + " %5," + " p;\n" "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x16x32 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x16x32_F16E4M3E5M2_SS_TN +// GMMA 64x16x32 TN S32+=U8*S8 +struct MMA_64x16x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -36547,84 +8853,81 @@ struct MMA_64x16x32_F16E4M3E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %6, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.f16.e4m3.e5m2 " - "{%0, %1, %2, %3}," - " %4," - " %5," - " p, %7, %8;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x16x32 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn 
scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x16x32_F16E4M3E5M2_RS_TN +// GMMA 64x16x32 TN S32+=U8*S8 +struct MMA_64x16x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %9, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.f16.e4m3.e5m2 " - "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," " %8," - " p, %10, %11;\n" + " %9," + " p;\n" "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x16x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x16x32_F32E4M3E5M2_SS_TN +// GMMA 64x32x32 TN S32+=U8*S8 +struct MMA_64x32x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[8]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, - float & d4, float & d5, float & d6, float & d7, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -36632,87 +8935,93 @@ struct MMA_64x16x32_F32E4M3E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %10, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - " %8," - " %9," - " p, %11, %12;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), - "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; 
//////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x16x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x16x32_F32E4M3E5M2_RS_TN +// GMMA 64x32x32 TN S32+=U8*S8 +struct MMA_64x32x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[8]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, - float & d4, float & d5, float & d6, float & d7, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %13, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "{%8, %9, %10, %11}," - " %12," - " p, %14, %15;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), - "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), 
"+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x32x32 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x32x32_F16E4M3E5M2_SS_TN +// GMMA 64x64x32 TN S32+=U8*S8 +struct MMA_64x64x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -36720,89 +9029,113 @@ struct MMA_64x32x32_F16E4M3E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %10, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, 
%7}," - " %8," - " %9," - " p, %11, %12;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x32x32 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x32x32_F16E4M3E5M2_RS_TN +// GMMA 64x64x32 TN S32+=U8*S8 +struct MMA_64x64x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, 
uint32_t & d7, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %13, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "{%8, %9, %10, %11}," - " %12," - " p, %14, %15;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E4M3E5M2_RS_TN 
without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x32x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x32x32_F32E4M3E5M2_SS_TN +// GMMA 64x96x32 TN S32+=U8*S8 +struct MMA_64x96x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[16]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -36810,90 +9143,113 @@ struct MMA_64x32x32_F32E4M3E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %18, 0;\n" - 
"wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e5m2 " + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - " %16," - " %17," - " p, %19, %20;\n" + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x32x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x32x32_F32E4M3E5M2_RS_TN +// GMMA 64x96x32 TN S32+=U8*S8 +struct MMA_64x96x32_S32U8S8_SS_TN_SATURATE { using 
DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[16]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %21, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e5m2 " + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - "{%16, %17, %18, %19}," - " %20," - " p, %22, %23;\n" + " %8, %9, %10, %11, %12, %13, %14, %15, " + 
" %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x32 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x48x32_F16E4M3E5M2_SS_TN +// GMMA 64x128x32 TN S32+=U8*S8 +struct MMA_64x128x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[12]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void fma(uint64_t const& 
desc_a, @@ -36901,6 +9257,19 @@ struct MMA_64x48x32_F16E4M3E5M2_SS_TN uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -36908,100 +9277,157 @@ struct MMA_64x48x32_F16E4M3E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %14, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11}," - " %12," - " %13," - " p, %15, %16;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" "}\n" : "+r"(d00), 
"+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x32 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x48x32_F16E4M3E5M2_RS_TN +// GMMA 64x128x32 TN S32+=U8*S8 +struct MMA_64x128x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[12]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t 
& d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %17, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11}," - "{%12, %13, %14, %15}," - " %16," - " p, %18, %19;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), 
"+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x48x32_F32E4M3E5M2_SS_TN +// GMMA 64x192x32 TN S32+=U8*S8 +struct MMA_64x192x32_S32U8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[24]; + using CRegisters = uint32_t[96]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, 
float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -37009,100 +9435,66 @@ struct MMA_64x48x32_F32E4M3E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %26, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e5m2 " - "{%0, 
%1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " p, %27, %28;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), 
"+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x48x32_F32E4M3E5M2_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[24]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %29, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " p, %30, %31;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), 
"+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x64x32 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x64x32_F16E4M3E5M2_SS_TN +// GMMA 64x192x32 TN S32+=U8*S8 +struct MMA_64x192x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; + using CRegisters = uint32_t[96]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -37111,6 +9503,26 @@ struct MMA_64x64x32_F16E4M3E5M2_SS_TN uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + 
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -37118,101 +9530,217 @@ struct MMA_64x64x32_F16E4M3E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %18, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - " %16," - " %17," - " p, %19, %20;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + "+r"(d12), 
"+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x64x32 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x64x32_F16E4M3E5M2_RS_TN +// GMMA 64x256x32 TN S32+=U8*S8 +struct MMA_64x256x32_S32U8S8_SS_TN { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; + using CRegisters = uint32_t[128]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, 
+ fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + 
uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %21, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - "{%16, %17, %18, %19}," - " %20," - " p, %22, %23;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), 
"+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E4M3E5M2_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x64x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x64x32_F32E4M3E5M2_SS_TN +// GMMA 64x256x32 TN S32+=U8*S8 +struct MMA_64x256x32_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[32]; + using CRegisters = uint32_t[128]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, 
uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -37220,58 +9748,83 @@ struct MMA_64x64x32_F32E4M3E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %34, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - " %32," - " %33," - " p, %35, %36;\n" + "setp.ne.b32 p, %130, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), 
"+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x64x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x64x32_F32E4M3E5M2_RS_TN +// GMMA 64x8x32 TN S32+=U8*S8 +struct MMA_64x8x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[32]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & 
d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -37279,110 +9832,76 @@ struct MMA_64x64x32_F32E4M3E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %37, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - "{%32, %33, %34, %35}," - " %36," - " p, %38, %39;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; 
//////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x32 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x80x32_F16E4M3E5M2_SS_TN +// GMMA 64x8x32 TN S32+=U8*S8 +struct MMA_64x8x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[20]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %22, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19}," - " %20," - " %21," - " p, %23, %24;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), 
"+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) - : "l"(desc_a), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x32 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x80x32_F16E4M3E5M2_RS_TN +// GMMA 64x16x32 TN S32+=U8*S8 +struct MMA_64x16x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[20]; + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -37390,124 +9909,81 @@ struct MMA_64x80x32_F16E4M3E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, 
%25, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19}," - "{%20, %21, %22, %23}," - " %24," - " p, %26, %27;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x80x32_F32E4M3E5M2_SS_TN +// GMMA 64x16x32 TN S32+=U8*S8 +struct MMA_64x16x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[40]; + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - 
float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %42, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " p, %43, %44;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) - : "l"(desc_a), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), 
"r"(a3), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x80x32_F32E4M3E5M2_RS_TN +// GMMA 64x32x32 TN S32+=U8*S8 +struct MMA_64x32x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[40]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -37515,104 
+9991,81 @@ struct MMA_64x80x32_F32E4M3E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %45, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e5m2 " + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " p, %46, %47;\n" + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x96x32 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x96x32_F16E4M3E5M2_SS_TN +// GMMA 64x32x32 TN S32+=U8*S8 
+struct MMA_64x32x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %26, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e5m2 " + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " p, %27, %28;\n" + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "l"(desc_a), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); 
#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x96x32 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x96x32_F16E4M3E5M2_RS_TN +// GMMA 64x64x32 TN S32+=U8*S8 +struct MMA_64x64x32_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -37623,6 +10076,8 @@ struct MMA_64x96x32_F16E4M3E5M2_RS_TN uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -37630,128 +10085,54 @@ struct MMA_64x96x32_F16E4M3E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %29, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e5m2 " + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " p, %30, %31;\n" + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), 
"+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x96x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x96x32_F32E4M3E5M2_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[48]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred 
p;\n" - "setp.ne.b32 p, %50, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " p, %51, %52;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x96x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x96x32_F32E4M3E5M2_RS_TN +// GMMA 64x64x32 TN S32+=U8*S8 +struct MMA_64x64x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[48]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float 
& d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -37759,56 +10140,45 @@ struct MMA_64x96x32_F32E4M3E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %53, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " p, %54, %55;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " 
%24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x32 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x112x32_F16E4M3E5M2_SS_TN +// GMMA 64x96x32 TN S32+=U8*S8 +struct MMA_64x96x32_S32U8S8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[28]; + using 
CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, @@ -37817,22 +10187,29 @@ struct MMA_64x112x32_F16E4M3E5M2_SS_TN uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %30, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k32.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27}," - " %28," - " %29," - " p, %31, %32;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -37840,31 +10217,30 @@ struct 
MMA_64x112x32_F16E4M3E5M2_SS_TN "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) - : "l"(desc_a), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x32 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x112x32_F16E4M3E5M2_RS_TN +// GMMA 64x96x32 TN S32+=U8*S8 +struct MMA_64x96x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[28]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -37876,6 +10252,11 @@ struct MMA_64x112x32_F16E4M3E5M2_RS_TN uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & 
d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -37883,15 +10264,17 @@ struct MMA_64x112x32_F16E4M3E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %33, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k32.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27}," - "{%28, %29, %30, %31}," - " %32," - " p, %34, %35;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -37899,125 +10282,125 @@ struct MMA_64x112x32_F16E4M3E5M2_RS_TN "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x112x32_F32E4M3E5M2_SS_TN +// GMMA 64x128x32 TN S32+=U8*S8 +struct MMA_64x128x32_S32U8S8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[56]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + 
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %58, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e5m2 " + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " p, %59, %60;\n" + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), 
"+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) - : "l"(desc_a), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x112x32_F32E4M3E5M2_RS_TN +// GMMA 64x128x32 TN S32+=U8*S8 +struct MMA_64x128x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[56]; + 
using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -38025,59 +10408,57 @@ struct MMA_64x112x32_F32E4M3E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %61, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e5m2 " + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " p, %62, %63;\n" + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), 
"+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x128x32 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x128x32_F16E4M3E5M2_SS_TN +// GMMA 64x192x32 TN S32+=U8*S8 +struct MMA_64x192x32_S32U8S8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; + using CRegisters = uint32_t[96]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, @@ -38087,22 +10468,46 @@ struct MMA_64x128x32_F16E4M3E5M2_SS_TN uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & 
d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %34, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - " %32," - " %33," - " p, %35, %36;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + 
" %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -38111,29 +10516,41 @@ struct MMA_64x128x32_F16E4M3E5M2_SS_TN "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) - : "l"(desc_a), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x128x32 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x128x32_F16E4M3E5M2_RS_TN +// GMMA 64x192x32 TN S32+=U8*S8 +struct MMA_64x192x32_S32U8S8_RS_TN_SATURATE { using 
DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; + using CRegisters = uint32_t[96]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -38146,6 +10563,22 @@ struct MMA_64x128x32_F16E4M3E5M2_RS_TN uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -38153,15 +10586,23 @@ struct MMA_64x128x32_F16E4M3E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %37, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " 
%24, %25, %26, %27, %28, %29, %30, %31}," - "{%32, %33, %34, %35}," - " %36," - " p, %38, %39;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -38170,58 +10611,201 @@ struct MMA_64x128x32_F16E4M3E5M2_RS_TN "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=U8*S8 +struct MMA_64x256x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, 
uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, 
%131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; 
//////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x128x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x128x32_F32E4M3E5M2_SS_TN +// GMMA 64x256x32 TN S32+=U8*S8 +struct MMA_64x256x32_S32U8S8_RS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[64]; + using CRegisters = uint32_t[128]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, 
uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, 
desc_b); + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %66, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e5m2 " + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -38229,142 +10813,112 @@ struct MMA_64x128x32_F32E4M3E5M2_SS_TN " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - " %64," - " %65," - " p, %67, %68;\n" + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) - : "l"(desc_a), + : "+r"(d000), "+r"(d001), "+r"(d002), 
"+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; 
//////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x128x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x128x32_F32E4M3E5M2_RS_TN +// GMMA 64x8x32 TN S32+=U8*U8 +struct MMA_64x8x32_S32U8U8_SS_TN { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[64]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %69, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e5m2 " - 
"{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - "{%64, %65, %66, %67}," - " %68," - " p, %70, %71;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x32 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x144x32_F16E4M3E5M2_SS_TN +// GMMA 64x8x32 TN S32+=U8*U8 +struct MMA_64x8x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[36]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -38372,136 +10926,78 @@ struct MMA_64x144x32_F16E4M3E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %38, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k32.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35}," - " %36," - " %37," - " p, %39, %40;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), 
"+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x32 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x144x32_F16E4M3E5M2_RS_TN +// GMMA 64x16x32 TN S32+=U8*U8 +struct MMA_64x16x32_S32U8U8_SS_TN { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[36]; + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t 
& d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %41, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k32.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35}," - "{%36, %37, %38, %39}," - " %40," - " p, %42, %43;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x144x32_F32E4M3E5M2_SS_TN +// GMMA 64x16x32 TN S32+=U8*U8 +struct MMA_64x16x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[72]; + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -38509,149 +11005,78 @@ struct MMA_64x144x32_F32E4M3E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %74, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k32.f32.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, 
%35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - " %72," - " %73," - " p, %75, %76;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One -> -struct MMA_64x144x32_F32E4M3E5M2_RS_TN +// GMMA 64x32x32 TN S32+=U8*U8 +struct MMA_64x32x32_S32U8U8_SS_TN { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[72]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred 
p;\n" - "setp.ne.b32 p, %77, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k32.f32.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - "{%72, %73, %74, %75}," - " %76," - " p, %78, %79;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x144x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x32 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x160x32_F16E4M3E5M2_SS_TN +// GMMA 64x32x32 TN S32+=U8*U8 +struct MMA_64x32x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -38660,12 +11085,6 @@ struct MMA_64x160x32_F16E4M3E5M2_SS_TN uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -38673,54 +11092,39 @@ struct MMA_64x160x32_F16E4M3E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %42, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k32.f16.e4m3.e5m2 " + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, 
%25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " p, %43, %44;\n" + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x32 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x160x32_F16E4M3E5M2_RS_TN +// GMMA 64x64x32 TN S32+=U8*U8 +struct MMA_64x64x32_S32U8U8_SS_TN { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, @@ 
-38730,25 +11134,22 @@ struct MMA_64x160x32_F16E4M3E5M2_RS_TN uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %45, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k32.f16.e4m3.e5m2 " + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " p, %46, %47;\n" + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -38757,57 +11158,37 @@ struct MMA_64x160x32_F16E4M3E5M2_RS_TN "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x64x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x160x32_F32E4M3E5M2_SS_TN +// GMMA 64x64x32 TN S32+=U8*U8 +struct MMA_64x64x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[80]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, 
uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -38815,157 +11196,42 @@ struct MMA_64x160x32_F32E4M3E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %82, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - " %80," - " %81," - " p, %83, %84;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), 
"+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x160x32_F32E4M3E5M2_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[80]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float 
& d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %85, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - "{%80, %81, %82, %83}," - " %84," - " p, %86, %87;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), 
"+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x32 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x176x32_F16E4M3E5M2_SS_TN +// GMMA 64x96x32 TN S32+=U8*U8 +struct MMA_64x96x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[44]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -38981,6 +11247,7 @@ struct MMA_64x176x32_F16E4M3E5M2_SS_TN uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -38988,17 +11255,17 @@ struct MMA_64x176x32_F16E4M3E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %46, 0;\n" - 
"wgmma.mma_async.sync.aligned.m64n176k32.f16.e4m3.e5m2 " + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43}," - " %44," - " %45," - " p, %47, %48;\n" + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -39010,34 +11277,29 @@ struct MMA_64x176x32_F16E4M3E5M2_SS_TN "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x32 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x176x32_F16E4M3E5M2_RS_TN +// GMMA 64x96x32 TN S32+=U8*U8 +struct MMA_64x96x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[44]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t 
const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, @@ -39050,24 +11312,25 @@ struct MMA_64x176x32_F16E4M3E5M2_RS_TN uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %49, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k32.f16.e4m3.e5m2 " + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43}," - "{%44, %45, %46, %47}," + " %40, %41, %42, %43, %44, %45, %46, %47}," " %48," - " p, %50, %51;\n" + " %49," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -39079,57 +11342,46 @@ struct MMA_64x176x32_F16E4M3E5M2_RS_TN "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x176x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x176x32_F32E4M3E5M2_SS_TN +// GMMA 64x128x32 TN S32+=U8*U8 +struct MMA_64x128x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[88]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & 
d84, float & d85, float & d86, float & d87, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -39137,8 +11389,8 @@ struct MMA_64x176x32_F32E4M3E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %90, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k32.f32.e4m3.e5m2 " + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -39146,95 +11398,74 @@ struct MMA_64x176x32_F32E4M3E5M2_SS_TN " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - " %88," - " %89," - " p, 
%91, %92;\n" + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x176x32_F32E4M3E5M2_RS_TN +// GMMA 64x128x32 TN S32+=U8*U8 +struct MMA_64x128x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[88]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & 
d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %93, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k32.f32.e4m3.e5m2 " + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -39242,59 +11473,45 @@ struct MMA_64x176x32_F32E4M3E5M2_RS_TN " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, 
%47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - "{%88, %89, %90, %91}," - " %92," - " p, %94, %95;\n" + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), 
"+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x192x32 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x192x32_F16E4M3E5M2_SS_TN +// GMMA 64x192x32 TN S32+=U8*U8 +struct MMA_64x192x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; + using CRegisters = uint32_t[96]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -39311,6 +11528,18 @@ struct MMA_64x192x32_F16E4M3E5M2_SS_TN uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, 
uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -39318,17 +11547,23 @@ struct MMA_64x192x32_F16E4M3E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %50, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k32.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " p, %51, %52;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -39341,32 +11576,40 @@ struct MMA_64x192x32_F16E4M3E5M2_SS_TN "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + 
"+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA 64x192x32 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x192x32_F16E4M3E5M2_RS_TN +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=U8*U8 +struct MMA_64x192x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; + using CRegisters = uint32_t[96]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, @@ -39380,24 +11623,42 @@ struct MMA_64x192x32_F16E4M3E5M2_RS_TN uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, 
uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %53, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k32.f16.e4m3.e5m2 " + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " p, %54, %55;\n" + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -39410,57 +11671,73 @@ 
struct MMA_64x192x32_F16E4M3E5M2_RS_TN "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x192x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x192x32_F32E4M3E5M2_SS_TN +// GMMA 64x256x32 TN S32+=U8*U8 +struct MMA_64x256x32_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[96]; + using CRegisters = uint32_t[128]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - 
float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, - float & d88, float & d89, float & d90, float & d91, - float & d92, float & d93, float & d94, float & d95, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t 
& d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -39468,8 +11745,8 @@ struct MMA_64x192x32_F32E4M3E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %98, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e5m2 " + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -39481,94 +11758,110 @@ struct MMA_64x192x32_F32E4M3E5M2_SS_TN " %64, %65, %66, %67, %68, %69, %70, %71, " " %72, %73, %74, %75, %76, 
%77, %78, %79, " " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - " %96," - " %97," - " p, %99, %100;\n" + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), - "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), - "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), 
"+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x192x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x192x32_F32E4M3E5M2_RS_TN +// GMMA 64x256x32 TN S32+=U8*U8 +struct MMA_64x256x32_S32U8U8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using 
BRegisters = uint64_t[1]; - using CRegisters = float[96]; + using CRegisters = uint32_t[128]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, - float & d88, float & d89, float & d90, float & d91, - float & d92, float & d93, float & d94, float & d95, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & 
d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %101, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e5m2 " + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -39580,149 +11873,108 @@ struct MMA_64x192x32_F32E4M3E5M2_RS_TN " %64, %65, %66, %67, %68, %69, %70, %71, " " %72, %73, %74, %75, %76, %77, %78, %79, " " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - "{%96, %97, %98, %99}," - " %100," - " p, %102, %103;\n" + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), 
"+f"(d85), "+f"(d86), "+f"(d87), - "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), - "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E4M3E5M2_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x32 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x208x32_F16E4M3E5M2_SS_TN +// GMMA 64x8x32 TN S32+=U8*U8 +struct MMA_64x8x32_S32U8U8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[52]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %54, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k32.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51}," - " %52," - " %53," - " p, %55, %56;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) - : "l"(desc_a), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x32 TN F16+=E4M3*E5M2 
-template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x208x32_F16E4M3E5M2_RS_TN +// GMMA 64x8x32 TN S32+=U8*U8 +struct MMA_64x8x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[52]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -39730,192 +11982,120 @@ struct MMA_64x208x32_F16E4M3E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %57, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k32.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, 
%33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51}," - "{%52, %53, %54, %55}," - " %56," - " p, %58, %59;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x208x32_F32E4M3E5M2_SS_TN +// GMMA 64x16x32 TN S32+=U8*U8 +struct MMA_64x16x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + 
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=U8*U8 +struct MMA_64x16x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[104]; + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, 
float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %106, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k32.f32.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " 
%96, %97, %98, %99, %100, %101, %102, %103}," - " %104," - " %105," - " p, %107, %108;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) - : "l"(desc_a), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x16x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x208x32_F32E4M3E5M2_RS_TN +// GMMA 64x32x32 TN S32+=U8*U8 +struct MMA_64x32x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[104]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - 
float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -39923,151 +12103,81 @@ struct MMA_64x208x32_F32E4M3E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %109, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k32.f32.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - "{%104, %105, %106, %107}," - " %108," - " p, %110, %111;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), 
"+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x32 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn 
scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x224x32_F16E4M3E5M2_SS_TN +// GMMA 64x32x32 TN S32+=U8*U8 +struct MMA_64x32x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %58, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k32.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, 
%30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " p, %59, %60;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "l"(desc_a), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x32 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x224x32_F16E4M3E5M2_RS_TN +// GMMA 64x64x32 TN S32+=U8*U8 +struct MMA_64x64x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; 
using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -40080,12 +12190,6 @@ struct MMA_64x224x32_F16E4M3E5M2_RS_TN uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -40093,18 +12197,15 @@ struct MMA_64x224x32_F16E4M3E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %61, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k32.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " p, %62, %63;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -40113,180 +12214,96 @@ struct MMA_64x224x32_F16E4M3E5M2_RS_TN 
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x224x32_F32E4M3E5M2_SS_TN +// GMMA 64x64x32 TN S32+=U8*U8 +struct MMA_64x64x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[112]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & 
d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, 
desc_b); + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %114, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k32.f32.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - " %112," - " %113," - " p, %115, %116;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), 
"+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) - : "l"(desc_a), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x224x32_F32E4M3E5M2_RS_TN +// GMMA 64x96x32 TN S32+=U8*U8 +struct MMA_64x96x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[112]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static 
void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & 
d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -40294,81 +12311,51 @@ struct MMA_64x224x32_F32E4M3E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %117, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k32.f32.e4m3.e5m2 " + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - "{%112, %113, %114, %115}," - " %116," - " p, %118, %119;\n" + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), 
- "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting 
to use MMA_64x224x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x32 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x240x32_F16E4M3E5M2_SS_TN +// GMMA 64x96x32 TN S32+=U8*U8 +struct MMA_64x96x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[60]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, @@ -40382,29 +12369,24 @@ struct MMA_64x240x32_F16E4M3E5M2_SS_TN uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %62, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k32.f16.e4m3.e5m2 " + "setp.ne.b32 p, %53, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59}," - " %60," - " %61," - " p, %63, %64;\n" + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -40417,34 +12399,25 @@ struct MMA_64x240x32_F16E4M3E5M2_SS_TN "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) - : "l"(desc_a), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x32 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x240x32_F16E4M3E5M2_RS_TN +// GMMA 64x128x32 TN S32+=U8*U8 +struct MMA_64x128x32_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = 
uint32_t[60]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -40464,6 +12437,7 @@ struct MMA_64x240x32_F16E4M3E5M2_RS_TN uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -40471,8 +12445,8 @@ struct MMA_64x240x32_F16E4M3E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %65, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k32.f16.e4m3.e5m2 " + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -40480,10 +12454,10 @@ struct MMA_64x240x32_F16E4M3E5M2_RS_TN " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59}," - "{%60, %61, %62, %63}," - " %64," - " p, %66, %67;\n" + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -40499,181 +12473,46 @@ struct MMA_64x240x32_F16E4M3E5M2_RS_TN "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E4M3E5M2_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x240x32_F32E4M3E5M2_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[120]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & 
d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %122, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k32.f32.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - " %120," - " %121," - " p, %123, %124;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), 
"+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x240x32_F32E4M3E5M2_RS_TN +// GMMA 64x128x32 TN S32+=U8*U8 +struct MMA_64x128x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[120]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& 
a003, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + 
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -40681,8 +12520,8 @@ struct MMA_64x240x32_F32E4M3E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %125, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k32.f32.e4m3.e5m2 " + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -40690,74 +12529,48 @@ struct MMA_64x240x32_F32E4M3E5M2_RS_TN " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, 
%119}," - "{%120, %121, %122, %123}," - " %124," - " p, %126, %127;\n" + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + 
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x256x32 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x256x32_F16E4M3E5M2_SS_TN +// GMMA 64x192x32 TN S32+=U8*U8 +struct MMA_64x192x32_S32U8U8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; + using CRegisters = uint32_t[96]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, @@ -40775,15 +12588,23 @@ struct MMA_64x256x32_F16E4M3E5M2_SS_TN uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & 
d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %66, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e5m2 " + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -40791,10 +12612,14 @@ struct MMA_64x256x32_F16E4M3E5M2_SS_TN " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - " %64," - " %65," - " p, %67, %68;\n" + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -40811,29 +12636,33 @@ struct MMA_64x256x32_F16E4M3E5M2_SS_TN "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "l"(desc_a), + 
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x256x32 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x256x32_F16E4M3E5M2_RS_TN +// GMMA 64x192x32 TN S32+=U8*U8 +struct MMA_64x192x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; + using CRegisters = uint32_t[96]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -40854,6 +12683,14 @@ struct MMA_64x256x32_F16E4M3E5M2_RS_TN uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t 
& d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -40861,8 +12698,8 @@ struct MMA_64x256x32_F16E4M3E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %69, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e5m2 " + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -40870,10 +12707,14 @@ struct MMA_64x256x32_F16E4M3E5M2_RS_TN " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - "{%64, %65, %66, %67}," - " %68," - " p, %70, %71;\n" + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -40890,184 +12731,184 @@ struct MMA_64x256x32_F16E4M3E5M2_RS_TN "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), 
"+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x256x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x256x32_F32E4M3E5M2_SS_TN +// GMMA 64x256x32 TN S32+=U8*U8 +struct MMA_64x256x32_S32U8U8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[128]; + using CRegisters = uint32_t[128]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, 
float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, - float & d120, float & d121, float & d122, float & d123, - float & d124, float & d125, float & d126, float & d127, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & 
d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %130, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, 
%71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - " %128," - " %129," - " p, %131, %132;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), - "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), - "+f"(d124), "+f"(d125), "+f"(d126), 
"+f"(d127) - : "l"(desc_a), +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), 
"+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x256x32 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x256x32_F32E4M3E5M2_RS_TN +// GMMA 64x256x32 TN S32+=U8*U8 +struct MMA_64x256x32_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[128]; + using CRegisters = uint32_t[128]; CUTE_HOST_DEVICE static void fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & 
d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, - float & d120, float & d121, float & d122, float & d123, - float & d124, float & d125, float & d126, float & d127, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, 
uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -41076,7 +12917,7 @@ struct 
MMA_64x256x32_F32E4M3E5M2_RS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %133, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e5m2 " + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -41095,57 +12936,57 @@ struct MMA_64x256x32_F32E4M3E5M2_RS_TN " %120, %121, %122, %123, %124, %125, %126, %127}," "{%128, %129, %130, %131}," " %132," - " p, %134, %135;\n" + " p;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), 
"+f"(d118), "+f"(d119), - "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), - "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E4M3E5M2_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x8x32 TN F16+=E5M2*E4M3 +// GMMA 64x8x32 TN F16+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x8x32_F16E5M2E4M3_SS_TN +struct MMA_64x8x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -41164,7 +13005,7 @@ struct MMA_64x8x32_F16E5M2E4M3_SS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %4, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e4m3 " + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e4m3 " "{%0, %1}," " %2," " %3," @@ -41175,19 +13016,19 @@ struct MMA_64x8x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x8x32 TN F16+=E5M2*E4M3 +// GMMA 64x8x32 TN F16+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x8x32_F16E5M2E4M3_RS_TN +struct MMA_64x8x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -41206,7 +13047,7 @@ struct MMA_64x8x32_F16E5M2E4M3_RS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %7, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e4m3 " + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e4m3 " "{%0, %1}," "{%2, %3, %4, %5}," " %6," @@ -41217,19 +13058,19 @@ struct MMA_64x8x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x8x32 TN F32+=E5M2*E4M3 +// GMMA 64x8x32 TN F32+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x8x32_F32E5M2E4M3_SS_TN +struct MMA_64x8x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -41248,7 +13089,7 @@ struct MMA_64x8x32_F32E5M2E4M3_SS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %6, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e4m3 " + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3 " "{%0, %1, %2, %3}," " %4," " %5," @@ -41259,19 +13100,19 @@ struct MMA_64x8x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x8x32 TN F32+=E5M2*E4M3 +// GMMA 64x8x32 TN F32+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x8x32_F32E5M2E4M3_RS_TN +struct MMA_64x8x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -41290,7 +13131,7 @@ struct MMA_64x8x32_F32E5M2E4M3_RS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %9, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e4m3 " + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3 " "{%0, %1, %2, 
%3}," "{%4, %5, %6, %7}," " %8," @@ -41301,19 +13142,19 @@ struct MMA_64x8x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x16x32 TN F16+=E5M2*E4M3 +// GMMA 64x16x32 TN F16+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x16x32_F16E5M2E4M3_SS_TN +struct MMA_64x16x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -41332,7 +13173,7 @@ struct MMA_64x16x32_F16E5M2E4M3_SS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %6, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.f16.e5m2.e4m3 " + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e4m3.e4m3 " "{%0, %1, %2, %3}," " %4," " %5," @@ -41343,19 +13184,19 @@ struct MMA_64x16x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x16x32 TN F16+=E5M2*E4M3 +// GMMA 64x16x32 TN F16+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x16x32_F16E5M2E4M3_RS_TN +struct MMA_64x16x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -41374,7 +13215,7 @@ struct MMA_64x16x32_F16E5M2E4M3_RS_TN "{\n" ".reg .pred p;\n" 
"setp.ne.b32 p, %9, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.f16.e5m2.e4m3 " + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e4m3.e4m3 " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," " %8," @@ -41385,19 +13226,19 @@ struct MMA_64x16x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x16x32 TN F32+=E5M2*E4M3 +// GMMA 64x16x32 TN F32+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x16x32_F32E5M2E4M3_SS_TN +struct MMA_64x16x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -41417,7 +13258,7 @@ struct MMA_64x16x32_F32E5M2E4M3_SS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %10, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.f32.e5m2.e4m3 " + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7}," " %8," " %9," @@ -41429,19 +13270,19 @@ struct MMA_64x16x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x16x32 TN F32+=E5M2*E4M3 +// GMMA 64x16x32 TN F32+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x16x32_F32E5M2E4M3_RS_TN +struct 
MMA_64x16x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -41461,7 +13302,7 @@ struct MMA_64x16x32_F32E5M2E4M3_RS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %13, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.f32.e5m2.e4m3 " + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7}," "{%8, %9, %10, %11}," " %12," @@ -41473,19 +13314,19 @@ struct MMA_64x16x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x32x32 TN F16+=E5M2*E4M3 +// GMMA 64x32x32 TN F16+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x32x32_F16E5M2E4M3_SS_TN +struct MMA_64x32x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -41505,7 +13346,7 @@ struct MMA_64x32x32_F16E5M2E4M3_SS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %10, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e4m3 " + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7}," " %8," " %9," @@ -41517,19 +13358,19 @@ struct MMA_64x32x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x32x32 TN 
F16+=E5M2*E4M3 +// GMMA 64x32x32 TN F16+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x32x32_F16E5M2E4M3_RS_TN +struct MMA_64x32x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -41549,7 +13390,7 @@ struct MMA_64x32x32_F16E5M2E4M3_RS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %13, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e4m3 " + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7}," "{%8, %9, %10, %11}," " %12," @@ -41561,19 +13402,19 @@ struct MMA_64x32x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x32x32 TN F32+=E5M2*E4M3 +// GMMA 64x32x32 TN F32+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x32x32_F32E5M2E4M3_SS_TN +struct MMA_64x32x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -41595,7 +13436,7 @@ struct MMA_64x32x32_F32E5M2E4M3_SS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %18, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e4m3 " + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," " %16," @@ -41610,19 +13451,19 @@ struct MMA_64x32x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x32x32 TN F32+=E5M2*E4M3 +// GMMA 64x32x32 TN F32+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x32x32_F32E5M2E4M3_RS_TN +struct MMA_64x32x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -41644,7 +13485,7 @@ struct MMA_64x32x32_F32E5M2E4M3_RS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %21, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e4m3 " + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," "{%16, %17, %18, %19}," @@ -41659,229 +13500,19 @@ struct MMA_64x32x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x32 TN F16+=E5M2*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x48x32_F16E5M2E4M3_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[12]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %14, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11}," - " %12," - " %13," - " p, %15, %16;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x32 TN F16+=E5M2*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x48x32_F16E5M2E4M3_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[12]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %17, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11}," - "{%12, %13, %14, %15}," - " %16," - " p, %18, %19;\n" - "}\n" - : "+r"(d00), 
"+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x32 TN F32+=E5M2*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x48x32_F32E5M2E4M3_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[24]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %26, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.f32.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " p, %27, %28;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - 
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x32 TN F32+=E5M2*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x48x32_F32E5M2E4M3_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = float[24]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %29, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.f32.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " p, %30, %31;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) - 
: "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x64x32 TN F16+=E5M2*E4M3 +// GMMA 64x64x32 TN F16+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x64x32_F16E5M2E4M3_SS_TN +struct MMA_64x64x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -41903,7 +13534,7 @@ struct MMA_64x64x32_F16E5M2E4M3_SS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %18, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e4m3 " + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," " %16," @@ -41918,19 +13549,19 @@ struct MMA_64x64x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x64x32 TN F16+=E5M2*E4M3 +// GMMA 64x64x32 TN F16+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x64x32_F16E5M2E4M3_RS_TN +struct MMA_64x64x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -41952,7 +13583,7 @@ struct MMA_64x64x32_F16E5M2E4M3_RS_TN "{\n" ".reg .pred 
p;\n" "setp.ne.b32 p, %21, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e4m3 " + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," "{%16, %17, %18, %19}," @@ -41967,19 +13598,19 @@ struct MMA_64x64x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x64x32 TN F32+=E5M2*E4M3 +// GMMA 64x64x32 TN F32+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x64x32_F32E5M2E4M3_SS_TN +struct MMA_64x64x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -42005,7 +13636,7 @@ struct MMA_64x64x32_F32E5M2E4M3_SS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %34, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3 " + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -42026,19 +13657,19 @@ struct MMA_64x64x32_F32E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x64x32 TN F32+=E5M2*E4M3 +// GMMA 64x64x32 TN F32+=E4M3*E4M3 template < 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x64x32_F32E5M2E4M3_RS_TN +struct MMA_64x64x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -42064,7 +13695,7 @@ struct MMA_64x64x32_F32E5M2E4M3_RS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %37, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3 " + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -42085,25 +13716,24 @@ struct MMA_64x64x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x32 TN F16+=E5M2*E4M3 +// GMMA 64x96x32 TN F16+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x80x32_F16E5M2E4M3_SS_TN +struct MMA_64x96x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[20]; + using CRegisters = uint32_t[24]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -42113,6 +13743,7 @@ struct MMA_64x80x32_F16E5M2E4M3_SS_TN uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -42120,44 +13751,43 @@ struct MMA_64x80x32_F16E5M2E4M3_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %22, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.f16.e5m2.e4m3 " + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19}," - " %20," - " %21," - " p, %23, %24;\n" + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x32 TN F16+=E5M2*E4M3 +// GMMA 64x96x32 TN F16+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x80x32_F16E5M2E4M3_RS_TN +struct MMA_64x96x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[20]; + using CRegisters = uint32_t[24]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -42167,6 +13797,7 @@ struct MMA_64x80x32_F16E5M2E4M3_RS_TN uint32_t & d08, uint32_t & d09, 
uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -42174,44 +13805,43 @@ struct MMA_64x80x32_F16E5M2E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %25, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.f16.e5m2.e4m3 " + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19}," - "{%20, %21, %22, %23}," - " %24," - " p, %26, %27;\n" + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x32 TN F32+=E5M2*E4M3 +// GMMA 64x96x32 TN F32+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x80x32_F32E5M2E4M3_SS_TN +struct MMA_64x96x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using 
ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[40]; + using CRegisters = float[48]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -42226,6 +13856,8 @@ struct MMA_64x80x32_F32E5M2E4M3_SS_TN float & d28, float & d29, float & d30, float & d31, float & d32, float & d33, float & d34, float & d35, float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -42233,16 +13865,17 @@ struct MMA_64x80x32_F32E5M2E4M3_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %42, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.f32.e5m2.e4m3 " + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " p, %43, %44;\n" + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -42253,31 +13886,31 @@ struct MMA_64x80x32_F32E5M2E4M3_SS_TN "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x96x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x32 TN F32+=E5M2*E4M3 +// GMMA 64x96x32 TN F32+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x80x32_F32E5M2E4M3_RS_TN +struct MMA_64x96x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[40]; + using CRegisters = float[48]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -42292,6 +13925,8 @@ struct MMA_64x80x32_F32E5M2E4M3_RS_TN float & d28, float & d29, float & d30, float & d31, float & d32, float & d33, float & d34, float & d35, float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -42299,16 +13934,17 @@ struct MMA_64x80x32_F32E5M2E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %45, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.f32.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " p, %46, %47;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, 
%47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -42319,30 +13955,31 @@ struct MMA_64x80x32_F32E5M2E4M3_RS_TN "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x96x32 TN F16+=E5M2*E4M3 +// GMMA 64x128x32 TN F16+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x96x32_F16E5M2E4M3_SS_TN +struct MMA_64x128x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -42353,6 +13990,8 @@ struct MMA_64x96x32_F16E5M2E4M3_SS_TN uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -42360,43 +13999,46 @@ 
struct MMA_64x96x32_F16E5M2E4M3_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %26, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e4m3 " + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " p, %27, %28;\n" + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x96x32 TN F16+=E5M2*E4M3 +// GMMA 64x128x32 TN F16+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x96x32_F16E5M2E4M3_RS_TN +struct MMA_64x128x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -42407,6 +14049,8 @@ struct MMA_64x96x32_F16E5M2E4M3_RS_TN 
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -42414,43 +14058,46 @@ struct MMA_64x96x32_F16E5M2E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %29, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e4m3 " + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " p, %30, %31;\n" + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x96x32 TN F32+=E5M2*E4M3 +// GMMA 64x128x32 TN F32+=E4M3*E4M3 template < 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x96x32_F32E5M2E4M3_SS_TN +struct MMA_64x128x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[48]; + using CRegisters = float[64]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -42467,6 +14114,10 @@ struct MMA_64x96x32_F32E5M2E4M3_SS_TN float & d36, float & d37, float & d38, float & d39, float & d40, float & d41, float & d42, float & d43, float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -42474,17 +14125,19 @@ struct MMA_64x96x32_F32E5M2E4M3_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %50, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " p, %51, %52;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -42497,29 
+14150,33 @@ struct MMA_64x96x32_F32E5M2E4M3_SS_TN "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x96x32 TN F32+=E5M2*E4M3 +// GMMA 64x128x32 TN F32+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x96x32_F32E5M2E4M3_RS_TN +struct MMA_64x128x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[48]; + using CRegisters = float[64]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -42536,6 +14193,10 @@ struct MMA_64x96x32_F32E5M2E4M3_RS_TN float & d36, float & d37, float & d38, float & d39, float & d40, float & d41, float & d42, float & d43, float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -42543,17 +14204,19 @@ struct 
MMA_64x96x32_F32E5M2E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %53, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e4m3 " + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " p, %54, %55;\n" + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -42566,30 +14229,33 @@ struct MMA_64x96x32_F32E5M2E4M3_RS_TN "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x32 TN F16+=E5M2*E4M3 +// GMMA 64x192x32 TN F16+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One > -struct MMA_64x112x32_F16E5M2E4M3_SS_TN +struct MMA_64x192x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[28]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -42601,6 +14267,11 @@ struct MMA_64x112x32_F16E5M2E4M3_SS_TN uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -42608,15 +14279,17 @@ struct MMA_64x112x32_F16E5M2E4M3_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %30, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k32.f16.e5m2.e4m3 " + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27}," - " %28," - " %29," - " p, %31, %32;\n" + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -42624,31 +14297,34 @@ struct MMA_64x112x32_F16E5M2E4M3_SS_TN "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + "+r"(d24), 
"+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x32 TN F16+=E5M2*E4M3 +// GMMA 64x192x32 TN F16+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x112x32_F16E5M2E4M3_RS_TN +struct MMA_64x192x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[28]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -42660,6 +14336,11 @@ struct MMA_64x112x32_F16E5M2E4M3_RS_TN uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -42667,15 
+14348,17 @@ struct MMA_64x112x32_F16E5M2E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %33, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k32.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27}," - "{%28, %29, %30, %31}," - " %32," - " p, %34, %35;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -42683,31 +14366,34 @@ struct MMA_64x112x32_F16E5M2E4M3_RS_TN "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x32 TN F32+=E5M2*E4M3 +// GMMA 
64x192x32 TN F32+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x112x32_F32E5M2E4M3_SS_TN +struct MMA_64x192x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[56]; + using CRegisters = float[96]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -42726,6 +14412,16 @@ struct MMA_64x112x32_F32E5M2E4M3_SS_TN float & d44, float & d45, float & d46, float & d47, float & d48, float & d49, float & d50, float & d51, float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -42733,18 +14429,23 @@ struct MMA_64x112x32_F32E5M2E4M3_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %58, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k32.f32.e5m2.e4m3 " + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " p, %59, %60;\n" + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, 
%66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -42759,31 +14460,39 @@ struct MMA_64x112x32_F32E5M2E4M3_SS_TN "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x32 TN F32+=E5M2*E4M3 +// GMMA 64x192x32 TN F32+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x112x32_F32E5M2E4M3_RS_TN +struct MMA_64x192x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[56]; + using CRegisters = float[96]; CUTE_HOST_DEVICE static 
void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -42802,6 +14511,16 @@ struct MMA_64x112x32_F32E5M2E4M3_RS_TN float & d44, float & d45, float & d46, float & d47, float & d48, float & d49, float & d50, float & d51, float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -42809,18 +14528,23 @@ struct MMA_64x112x32_F32E5M2E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %61, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k32.f32.e5m2.e4m3 " + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " p, %62, %63;\n" + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), 
"+f"(d06), "+f"(d07), @@ -42835,30 +14559,39 @@ struct MMA_64x112x32_F32E5M2E4M3_RS_TN "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x128x32 TN F16+=E5M2*E4M3 +// GMMA 64x256x32 TN F16+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x128x32_F16E5M2E4M3_SS_TN +struct MMA_64x256x32_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -42871,22 +14604,34 @@ struct MMA_64x128x32_F16E5M2E4M3_SS_TN uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, 
uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %34, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - " %32," - " %33," - " p, %35, %36;\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -42895,29 +14640,37 @@ struct MMA_64x128x32_F16E5M2E4M3_SS_TN "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), 
+ "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x128x32 TN F16+=E5M2*E4M3 +// GMMA 64x256x32 TN F16+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x128x32_F16E5M2E4M3_RS_TN +struct MMA_64x256x32_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -42930,6 +14683,14 @@ struct MMA_64x128x32_F16E5M2E4M3_RS_TN uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, 
uint32_t & d61, uint32_t & d62, uint32_t & d63, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -42937,15 +14698,19 @@ struct MMA_64x128x32_F16E5M2E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %37, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - "{%32, %33, %34, %35}," - " %36," - " p, %38, %39;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -42954,49 +14719,73 @@ struct MMA_64x128x32_F16E5M2E4M3_RS_TN "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x128x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x128x32 TN F32+=E5M2*E4M3 +// GMMA 64x256x32 TN F32+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x128x32_F32E5M2E4M3_SS_TN +struct MMA_64x256x32_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[64]; + using CRegisters = float[128]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & 
d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -43004,8 +14793,8 @@ struct MMA_64x128x32_F32E5M2E4M3_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %66, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e4m3 " + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, 
%22, %23, " @@ -43013,69 +14802,109 @@ struct MMA_64x128x32_F32E5M2E4M3_SS_TN " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - " %64," - " %65," - " p, %67, %68;\n" + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), 
"+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x128x32 TN F32+=E5M2*E4M3 +// GMMA 64x256x32 TN F32+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x128x32_F32E5M2E4M3_RS_TN +struct MMA_64x256x32_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[64]; + using CRegisters = float[128]; 
CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & 
d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -43083,8 +14912,8 @@ struct MMA_64x128x32_F32E5M2E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %69, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e4m3 " + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -43092,63 +14921,78 @@ struct MMA_64x128x32_F32E5M2E4M3_RS_TN " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - "{%64, %65, %66, %67}," - " %68," - " p, %70, %71;\n" + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, 
%98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), 
+ "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x32 TN F16+=E5M2*E4M3 +// GMMA 64x8x32 TN F16+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x144x32_F16E5M2E4M3_SS_TN +struct MMA_64x8x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[36]; + using CRegisters = uint32_t[2]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & 
d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d0, uint32_t & d1, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -43156,63 +15000,41 @@ struct MMA_64x144x32_F16E5M2E4M3_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %38, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k32.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35}," - " %36," - " %37," - " p, %39, %40;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e5m2 " + "{%0, %1}," + " %2," + " %3," + " p, %5, %6;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "+r"(d0), "+r"(d1) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x32 TN F16+=E5M2*E4M3 +// GMMA 64x8x32 TN F16+=E4M3*E5M2 template < 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x144x32_F16E5M2E4M3_RS_TN +struct MMA_64x8x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[36]; + using CRegisters = uint32_t[2]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d0, uint32_t & d1, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -43220,72 +15042,41 @@ struct MMA_64x144x32_F16E5M2E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %41, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k32.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35}," - "{%36, %37, %38, %39}," - " %40," - " p, %42, %43;\n" + "setp.ne.b32 p, %7, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e5m2 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " p, %8, %9;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), 
"+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x32 TN F32+=E5M2*E4M3 +// GMMA 64x8x32 TN F32+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x144x32_F32E5M2E4M3_SS_TN +struct MMA_64x8x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[72]; + using CRegisters = float[4]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & 
d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, + float & d0, float & d1, float & d2, float & d3, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -43293,85 +15084,41 @@ struct MMA_64x144x32_F32E5M2E4M3_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %74, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k32.f32.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - " %72," - " %73," - " p, %75, %76;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), 
"+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x32 TN F32+=E5M2*E4M3 +// GMMA 64x8x32 TN F32+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x144x32_F32E5M2E4M3_RS_TN +struct MMA_64x8x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[72]; + using CRegisters = float[4]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - 
float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, + float & d0, float & d1, float & d2, float & d3, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -43379,77 +15126,41 @@ struct MMA_64x144x32_F32E5M2E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %77, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k32.f32.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - "{%72, %73, %74, %75}," - " %76," - " p, %78, %79;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "setp.ne.b32 p, %9, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x32 TN F16+=E5M2*E4M3 +// GMMA 64x16x32 TN F16+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x160x32_F16E5M2E4M3_SS_TN +struct MMA_64x16x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -43457,65 +15168,41 @@ struct MMA_64x160x32_F16E5M2E4M3_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %42, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k32.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " p, %43, %44;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x32 TN F16+=E5M2*E4M3 +// GMMA 64x16x32 TN F16+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x160x32_F16E5M2E4M3_RS_TN +struct MMA_64x16x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - 
using CRegisters = uint32_t[40]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -43523,75 +15210,42 @@ struct MMA_64x160x32_F16E5M2E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %45, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k32.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " p, %46, %47;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), 
"+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x32 TN F32+=E5M2*E4M3 +// GMMA 64x16x32 TN F32+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x160x32_F32E5M2E4M3_SS_TN +struct MMA_64x16x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[80]; + using CRegisters = float[8]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, 
float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -43599,90 +15253,43 @@ struct MMA_64x160x32_F32E5M2E4M3_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %82, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k32.f32.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - " %80," - " %81," - " p, %83, %84;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), 
"+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x32 TN F32+=E5M2*E4M3 +// GMMA 64x16x32 TN F32+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x160x32_F32E5M2E4M3_RS_TN +struct MMA_64x16x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[80]; + using CRegisters = float[8]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, 
float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -43690,81 +15297,43 @@ struct MMA_64x160x32_F32E5M2E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %85, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k32.f32.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - "{%80, %81, %82, %83}," - " %84," - " p, %86, %87;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - 
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x32 TN F16+=E5M2*E4M3 +// GMMA 64x32x32 TN F16+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x176x32_F16E5M2E4M3_SS_TN +struct MMA_64x32x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[44]; + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, 
uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -43772,68 +15341,43 @@ struct MMA_64x176x32_F16E5M2E4M3_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %46, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k32.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43}," - " %44," - " %45," - " p, %47, %48;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + 
"+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x32 TN F16+=E5M2*E4M3 +// GMMA 64x32x32 TN F16+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x176x32_F16E5M2E4M3_RS_TN +struct MMA_64x32x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[44]; + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, 
uint32_t & d5, uint32_t & d6, uint32_t & d7, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -43841,53 +15385,37 @@ struct MMA_64x176x32_F16E5M2E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %49, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k32.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43}," - "{%44, %45, %46, %47}," - " %48," - " p, %50, %51;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 
64x176x32 TN F32+=E5M2*E4M3 +// GMMA 64x32x32 TN F32+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x176x32_F32E5M2E4M3_SS_TN +struct MMA_64x32x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[88]; + using CRegisters = float[16]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -43896,24 +15424,6 @@ struct MMA_64x176x32_F32E5M2E4M3_SS_TN float & d04, float & d05, float & d06, float & d07, float & d08, float & d09, float & d10, float & d11, float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -43921,69 +15431,40 @@ struct MMA_64x176x32_F32E5M2E4M3_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %90, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k32.f32.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, 
%13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - " %88," - " %89," - " p, %91, %92;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting 
to use MMA_64x32x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x32 TN F32+=E5M2*E4M3 +// GMMA 64x32x32 TN F32+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x176x32_F32E5M2E4M3_RS_TN +struct MMA_64x32x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[88]; + using CRegisters = float[16]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -43992,24 +15473,6 @@ struct MMA_64x176x32_F32E5M2E4M3_RS_TN float & d04, float & d05, float & d06, float & d07, float & d08, float & d09, float & d10, float & d11, float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -44017,68 +15480,40 @@ struct MMA_64x176x32_F32E5M2E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %93, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k32.f32.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - "{%88, %89, %90, %91}," - " %92," - " p, %94, %95;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), 
- "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x192x32 TN F16+=E5M2*E4M3 +// GMMA 64x64x32 TN F16+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x192x32_F16E5M2E4M3_SS_TN +struct MMA_64x64x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -44087,14 +15522,6 @@ struct MMA_64x192x32_F16E5M2E4M3_SS_TN uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -44102,52 +15529,40 @@ struct MMA_64x192x32_F16E5M2E4M3_SS_TN asm 
volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %50, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k32.f16.e5m2.e4m3 " + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " p, %51, %52;\n" + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x192x32 TN F16+=E5M2*E4M3 +// GMMA 64x64x32 TN F16+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x192x32_F16E5M2E4M3_RS_TN +struct MMA_64x64x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using 
CRegisters = uint32_t[48]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -44156,14 +15571,6 @@ struct MMA_64x192x32_F16E5M2E4M3_RS_TN uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -44171,52 +15578,40 @@ struct MMA_64x192x32_F16E5M2E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %53, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k32.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " p, %54, %55;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), 
"+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x192x32 TN F32+=E5M2*E4M3 +// GMMA 64x64x32 TN F32+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x192x32_F32E5M2E4M3_SS_TN +struct MMA_64x64x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[96]; + using CRegisters = float[32]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -44229,22 +15624,6 @@ struct MMA_64x192x32_F32E5M2E4M3_SS_TN float & d20, float & d21, float & d22, float & d23, float & d24, float & d25, float & d26, float & d27, float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & 
d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, - float & d88, float & d89, float & d90, float & d91, - float & d92, float & d93, float & d94, float & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -44252,23 +15631,15 @@ struct MMA_64x192x32_F32E5M2E4M3_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %98, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - " %96," - " %97," - " p, %99, %100;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -44277,45 +15648,29 @@ struct MMA_64x192x32_F32E5M2E4M3_SS_TN "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), 
"+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), - "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), - "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x192x32 TN F32+=E5M2*E4M3 +// GMMA 64x64x32 TN F32+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x192x32_F32E5M2E4M3_RS_TN +struct MMA_64x64x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[96]; + using CRegisters = float[32]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -44328,22 +15683,6 @@ struct MMA_64x192x32_F32E5M2E4M3_RS_TN float & d20, float & d21, float & d22, float & d23, float & d24, float & d25, float & d26, float & d27, float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, 
float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, - float & d88, float & d89, float & d90, float & d91, - float & d92, float & d93, float & d94, float & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -44351,23 +15690,15 @@ struct MMA_64x192x32_F32E5M2E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %101, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - "{%96, %97, %98, %99}," - " %100," - " p, %102, %103;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" "}\n" : 
"+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -44376,46 +15707,29 @@ struct MMA_64x192x32_F32E5M2E4M3_RS_TN "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), - "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), - "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x32 TN F16+=E5M2*E4M3 +// GMMA 64x96x32 TN F16+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x208x32_F16E5M2E4M3_SS_TN +struct MMA_64x96x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters 
= uint32_t[52]; + using CRegisters = uint32_t[24]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -44426,13 +15740,6 @@ struct MMA_64x208x32_F16E5M2E4M3_SS_TN uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -44440,56 +15747,43 @@ struct MMA_64x208x32_F16E5M2E4M3_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %54, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k32.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51}," - " %52," - " %53," - " p, %55, %56;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), 
"+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x32 TN F16+=E5M2*E4M3 +// GMMA 64x96x32 TN F16+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x208x32_F16E5M2E4M3_RS_TN +struct MMA_64x96x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[52]; + using CRegisters = uint32_t[24]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -44500,13 +15794,6 @@ struct MMA_64x208x32_F16E5M2E4M3_RS_TN uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, 
uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -44514,86 +15801,59 @@ struct MMA_64x208x32_F16E5M2E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %57, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k32.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51}," - "{%52, %53, %54, %55}," - " %56," - " p, %58, %59;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E4M3E5M2_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x32 TN F32+=E5M2*E4M3 +// GMMA 64x96x32 TN F32+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x208x32_F32E5M2E4M3_SS_TN +struct MMA_64x96x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[104]; + using CRegisters = float[48]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, 
float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -44601,105 +15861,68 @@ struct MMA_64x208x32_F32E5M2E4M3_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %106, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k32.f32.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - " %104," - " %105," - " p, %107, %108;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, 
%25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), 
"+f"(d45), "+f"(d46), "+f"(d47) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x32 TN F32+=E5M2*E4M3 +// GMMA 64x96x32 TN F32+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x208x32_F32E5M2E4M3_RS_TN +struct MMA_64x96x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[104]; + using CRegisters = float[48]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - 
float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -44707,75 +15930,52 @@ struct MMA_64x208x32_F32E5M2E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %109, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k32.f32.e5m2.e4m3 " + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, 
%52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - "{%104, %105, %106, %107}," - " %108," - " p, %110, %111;\n" + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), 
"+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x32 TN F16+=E5M2*E4M3 +// GMMA 64x128x32 TN F16+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x224x32_F16E5M2E4M3_SS_TN +struct MMA_64x128x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -44788,12 +15988,6 @@ struct MMA_64x224x32_F16E5M2E4M3_SS_TN uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t 
& d53, uint32_t & d54, uint32_t & d55, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -44801,18 +15995,15 @@ struct MMA_64x224x32_F16E5M2E4M3_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %58, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k32.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " p, %59, %60;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -44821,37 +16012,29 @@ struct MMA_64x224x32_F16E5M2E4M3_SS_TN "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif 
//////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x32 TN F16+=E5M2*E4M3 +// GMMA 64x128x32 TN F16+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x224x32_F16E5M2E4M3_RS_TN +struct MMA_64x128x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -44864,12 +16047,6 @@ struct MMA_64x224x32_F16E5M2E4M3_RS_TN uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -44877,18 +16054,15 @@ struct MMA_64x224x32_F16E5M2E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %61, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k32.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " p, %62, %63;\n" + 
"setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -44897,69 +16071,49 @@ struct MMA_64x224x32_F16E5M2E4M3_RS_TN "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x32 TN F32+=E5M2*E4M3 +// GMMA 64x128x32 TN F32+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x224x32_F32E5M2E4M3_SS_TN +struct MMA_64x128x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[112]; + using CRegisters = float[64]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& 
desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, 
float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -44967,8 +16121,8 @@ struct MMA_64x224x32_F32E5M2E4M3_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %114, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k32.f32.e5m2.e4m3 " + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -44976,101 +16130,69 @@ struct MMA_64x224x32_F32E5M2E4M3_SS_TN " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - " %112," - " %113," - " p, %115, %116;\n" + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), 
"+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E5M2E4M3_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x32 TN F32+=E5M2*E4M3 +// GMMA 64x128x32 TN F32+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x224x32_F32E5M2E4M3_RS_TN +struct MMA_64x128x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[112]; + using CRegisters = float[64]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, 
float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -45078,8 +16200,8 @@ struct MMA_64x224x32_F32E5M2E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %117, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k32.f32.e5m2.e4m3 " + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -45087,69 +16209,49 @@ struct 
MMA_64x224x32_F32E5M2E4M3_RS_TN " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - "{%112, %113, %114, %115}," - " %116," - " p, %118, %119;\n" + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), 
"+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x32 TN F16+=E5M2*E4M3 +// GMMA 64x192x32 TN F16+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x240x32_F16E5M2E4M3_SS_TN +struct MMA_64x192x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[60]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -45166,9 +16268,6 @@ struct MMA_64x240x32_F16E5M2E4M3_SS_TN uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, 
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -45176,19 +16275,17 @@ struct MMA_64x240x32_F16E5M2E4M3_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %62, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k32.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59}," - " %60," - " %61," - " p, %63, %64;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -45201,34 +16298,29 @@ struct MMA_64x240x32_F16E5M2E4M3_SS_TN "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x32 TN F16+=E5M2*E4M3 +// GMMA 64x192x32 TN F16+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x240x32_F16E5M2E4M3_RS_TN +struct MMA_64x192x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[60]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -45245,9 +16337,6 @@ struct MMA_64x240x32_F16E5M2E4M3_RS_TN uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -45255,19 +16344,17 @@ struct MMA_64x240x32_F16E5M2E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %65, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k32.f16.e5m2.e4m3 " + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e4m3.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, 
%39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59}," - "{%60, %61, %62, %63}," - " %64," - " p, %66, %67;\n" + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -45280,68 +16367,57 @@ struct MMA_64x240x32_F16E5M2E4M3_RS_TN "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x32 TN F32+=E5M2*E4M3 +// GMMA 64x192x32 TN F32+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x240x32_F32E5M2E4M3_SS_TN +struct MMA_64x192x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[120]; + using CRegisters = float[96]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & 
d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & 
d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -45349,8 +16425,8 @@ struct MMA_64x240x32_F32E5M2E4M3_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %122, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k32.f32.e5m2.e4m3 " + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -45362,102 +16438,85 @@ struct MMA_64x240x32_F32E5M2E4M3_SS_TN " %64, %65, %66, %67, %68, %69, %70, %71, " " %72, %73, %74, %75, %76, %77, %78, %79, " " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - " %120," - " %121," - " p, %123, %124;\n" + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), 
"+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), 
"+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x32 TN F32+=E5M2*E4M3 +// GMMA 64x192x32 TN F32+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x240x32_F32E5M2E4M3_RS_TN +struct MMA_64x192x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[120]; + using CRegisters = float[96]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & 
d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float 
& d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -45465,8 +16524,8 @@ struct MMA_64x240x32_F32E5M2E4M3_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %125, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k32.f32.e5m2.e4m3 " + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -45478,62 +16537,52 @@ struct MMA_64x240x32_F32E5M2E4M3_RS_TN " %64, %65, %66, %67, %68, %69, %70, %71, " " %72, %73, %74, %75, %76, %77, %78, %79, " " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - "{%120, %121, %122, %123}," - " %124," - " p, %126, %127;\n" + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), 
"+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), 
"+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x256x32 TN F16+=E5M2*E4M3 +// GMMA 64x256x32 TN F16+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x256x32_F16E5M2E4M3_SS_TN +struct MMA_64x256x32_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -45567,7 +16616,7 @@ struct MMA_64x256x32_F16E5M2E4M3_SS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %66, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k32.f16.e5m2.e4m3 " + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -45600,19 +16649,19 @@ struct MMA_64x256x32_F16E5M2E4M3_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; 
//////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x256x32 TN F16+=E5M2*E4M3 +// GMMA 64x256x32 TN F16+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x256x32_F16E5M2E4M3_RS_TN +struct MMA_64x256x32_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -45646,7 +16695,7 @@ struct MMA_64x256x32_F16E5M2E4M3_RS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %69, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k32.f16.e5m2.e4m3 " + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -45679,19 +16728,19 @@ struct MMA_64x256x32_F16E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x256x32 TN F32+=E5M2*E4M3 +// GMMA 64x256x32 TN F32+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x256x32_F32E5M2E4M3_SS_TN +struct MMA_64x256x32_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -45741,7 +16790,7 @@ struct MMA_64x256x32_F32E5M2E4M3_SS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %130, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e4m3 " + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -45798,19 +16847,19 @@ struct MMA_64x256x32_F32E5M2E4M3_SS_TN 
"l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x256x32 TN F32+=E5M2*E4M3 +// GMMA 64x256x32 TN F32+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x256x32_F32E5M2E4M3_RS_TN +struct MMA_64x256x32_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -45860,7 +16909,7 @@ struct MMA_64x256x32_F32E5M2E4M3_RS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %133, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e4m3 " + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -45917,19 +16966,19 @@ struct MMA_64x256x32_F32E5M2E4M3_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x8x32 TN F16+=E5M2*E5M2 +// GMMA 64x8x32 TN F16+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x8x32_F16E5M2E5M2_SS_TN +struct MMA_64x8x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -45948,7 +16997,7 @@ struct MMA_64x8x32_F16E5M2E5M2_SS_TN "{\n" ".reg .pred p;\n" 
"setp.ne.b32 p, %4, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e5m2 " + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e4m3 " "{%0, %1}," " %2," " %3," @@ -45959,19 +17008,19 @@ struct MMA_64x8x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x8x32 TN F16+=E5M2*E5M2 +// GMMA 64x8x32 TN F16+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x8x32_F16E5M2E5M2_RS_TN +struct MMA_64x8x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -45990,7 +17039,7 @@ struct MMA_64x8x32_F16E5M2E5M2_RS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %7, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e5m2 " + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e4m3 " "{%0, %1}," "{%2, %3, %4, %5}," " %6," @@ -46001,19 +17050,19 @@ struct MMA_64x8x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x8x32 TN F32+=E5M2*E5M2 +// GMMA 64x8x32 TN F32+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x8x32_F32E5M2E5M2_SS_TN +struct MMA_64x8x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using 
ARegisters = uint64_t[1]; @@ -46032,7 +17081,7 @@ struct MMA_64x8x32_F32E5M2E5M2_SS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %6, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e5m2 " + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e4m3 " "{%0, %1, %2, %3}," " %4," " %5," @@ -46043,19 +17092,19 @@ struct MMA_64x8x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x8x32 TN F32+=E5M2*E5M2 +// GMMA 64x8x32 TN F32+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x8x32_F32E5M2E5M2_RS_TN +struct MMA_64x8x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -46074,7 +17123,7 @@ struct MMA_64x8x32_F32E5M2E5M2_RS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %9, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e5m2 " + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e4m3 " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," " %8," @@ -46085,19 +17134,19 @@ struct MMA_64x8x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x16x32 TN F16+=E5M2*E5M2 +// GMMA 64x16x32 TN F16+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One > -struct MMA_64x16x32_F16E5M2E5M2_SS_TN +struct MMA_64x16x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -46116,7 +17165,7 @@ struct MMA_64x16x32_F16E5M2E5M2_SS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %6, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.f16.e5m2.e5m2 " + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e5m2.e4m3 " "{%0, %1, %2, %3}," " %4," " %5," @@ -46127,19 +17176,19 @@ struct MMA_64x16x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x16x32 TN F16+=E5M2*E5M2 +// GMMA 64x16x32 TN F16+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x16x32_F16E5M2E5M2_RS_TN +struct MMA_64x16x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -46158,7 +17207,7 @@ struct MMA_64x16x32_F16E5M2E5M2_RS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %9, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.f16.e5m2.e5m2 " + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e5m2.e4m3 " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," " %8," @@ -46169,19 +17218,19 @@ struct MMA_64x16x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// 
-// GMMA 64x16x32 TN F32+=E5M2*E5M2 +// GMMA 64x16x32 TN F32+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x16x32_F32E5M2E5M2_SS_TN +struct MMA_64x16x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -46201,7 +17250,7 @@ struct MMA_64x16x32_F32E5M2E5M2_SS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %10, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.f32.e5m2.e5m2 " + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7}," " %8," " %9," @@ -46213,19 +17262,19 @@ struct MMA_64x16x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x16x32 TN F32+=E5M2*E5M2 +// GMMA 64x16x32 TN F32+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x16x32_F32E5M2E5M2_RS_TN +struct MMA_64x16x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -46245,7 +17294,7 @@ struct MMA_64x16x32_F32E5M2E5M2_RS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %13, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.f32.e5m2.e5m2 " + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7}," "{%8, %9, %10, %11}," " %12," @@ -46257,19 +17306,19 @@ struct MMA_64x16x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to 
use MMA_64x16x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x32x32 TN F16+=E5M2*E5M2 +// GMMA 64x32x32 TN F16+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x32x32_F16E5M2E5M2_SS_TN +struct MMA_64x32x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -46289,7 +17338,7 @@ struct MMA_64x32x32_F16E5M2E5M2_SS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %10, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e5m2 " + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7}," " %8," " %9," @@ -46301,19 +17350,19 @@ struct MMA_64x32x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x32x32 TN F16+=E5M2*E5M2 +// GMMA 64x32x32 TN F16+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x32x32_F16E5M2E5M2_RS_TN +struct MMA_64x32x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -46333,7 +17382,7 @@ struct MMA_64x32x32_F16E5M2E5M2_RS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %13, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e5m2 " + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7}," "{%8, %9, %10, %11}," " %12," @@ -46345,19 +17394,19 @@ struct MMA_64x32x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x32x32 TN F32+=E5M2*E5M2 +// GMMA 64x32x32 TN F32+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x32x32_F32E5M2E5M2_SS_TN +struct MMA_64x32x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -46379,7 +17428,7 @@ struct MMA_64x32x32_F32E5M2E5M2_SS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %18, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e5m2 " + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," " %16," @@ -46394,19 +17443,19 @@ struct MMA_64x32x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x32x32 TN F32+=E5M2*E5M2 +// GMMA 64x32x32 TN F32+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x32x32_F32E5M2E5M2_RS_TN +struct MMA_64x32x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -46428,7 +17477,7 @@ struct MMA_64x32x32_F32E5M2E5M2_RS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %21, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e5m2 " + 
"wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," "{%16, %17, %18, %19}," @@ -46443,229 +17492,19 @@ struct MMA_64x32x32_F32E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x32 TN F16+=E5M2*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x48x32_F16E5M2E5M2_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[12]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %14, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11}," - " %12," - " %13," - " p, %15, %16;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; 
-#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x32 TN F16+=E5M2*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x48x32_F16E5M2E5M2_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[12]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %17, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11}," - "{%12, %13, %14, %15}," - " %16," - " p, %18, %19;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x32 TN F32+=E5M2*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x48x32_F32E5M2E5M2_SS_TN 
-{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[24]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %26, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.f32.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " p, %27, %28;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) - : "l"(desc_a), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x48x32 TN F32+=E5M2*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x48x32_F32E5M2E5M2_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using BRegisters = uint64_t[1]; - using 
CRegisters = float[24]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %29, 0;\n" - "wgmma.mma_async.sync.aligned.m64n48k32.f32.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " p, %30, %31;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x64x32 TN F16+=E5M2*E5M2 +// GMMA 64x64x32 TN F16+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x64x32_F16E5M2E5M2_SS_TN +struct MMA_64x64x32_F16E5M2E4M3_SS_TN 
{ using DRegisters = void; using ARegisters = uint64_t[1]; @@ -46687,7 +17526,7 @@ struct MMA_64x64x32_F16E5M2E5M2_SS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %18, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e5m2 " + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," " %16," @@ -46702,19 +17541,19 @@ struct MMA_64x64x32_F16E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x64x32 TN F16+=E5M2*E5M2 +// GMMA 64x64x32 TN F16+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x64x32_F16E5M2E5M2_RS_TN +struct MMA_64x64x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -46736,7 +17575,7 @@ struct MMA_64x64x32_F16E5M2E5M2_RS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %21, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e5m2 " + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," "{%16, %17, %18, %19}," @@ -46751,19 +17590,19 @@ struct MMA_64x64x32_F16E5M2E5M2_RS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; 
//////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x64x32 TN F32+=E5M2*E5M2 +// GMMA 64x64x32 TN F32+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x64x32_F32E5M2E5M2_SS_TN +struct MMA_64x64x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -46789,7 +17628,7 @@ struct MMA_64x64x32_F32E5M2E5M2_SS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %34, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e5m2 " + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -46810,19 +17649,19 @@ struct MMA_64x64x32_F32E5M2E5M2_SS_TN "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x64x32 TN F32+=E5M2*E5M2 +// GMMA 64x64x32 TN F32+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x64x32_F32E5M2E5M2_RS_TN +struct MMA_64x64x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -46848,7 +17687,7 @@ struct MMA_64x64x32_F32E5M2E5M2_RS_TN "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %37, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e5m2 " + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -46869,25 +17708,24 @@ struct MMA_64x64x32_F32E5M2E5M2_RS_TN "l"(desc_b), 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x32 TN F16+=E5M2*E5M2 +// GMMA 64x96x32 TN F16+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x80x32_F16E5M2E5M2_SS_TN +struct MMA_64x96x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[20]; + using CRegisters = uint32_t[24]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -46897,6 +17735,7 @@ struct MMA_64x80x32_F16E5M2E5M2_SS_TN uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -46904,44 +17743,43 @@ struct MMA_64x80x32_F16E5M2E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %22, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.f16.e5m2.e5m2 " + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19}," - " %20," - " %21," - " p, %23, %24;\n" + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), 
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x32 TN F16+=E5M2*E5M2 +// GMMA 64x96x32 TN F16+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x80x32_F16E5M2E5M2_RS_TN +struct MMA_64x96x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[20]; + using CRegisters = uint32_t[24]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -46951,6 +17789,7 @@ struct MMA_64x80x32_F16E5M2E5M2_RS_TN uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -46958,44 +17797,43 @@ struct MMA_64x80x32_F16E5M2E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %25, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.f16.e5m2.e5m2 " + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, 
%18, %19}," - "{%20, %21, %22, %23}," - " %24," - " p, %26, %27;\n" + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x32 TN F32+=E5M2*E5M2 +// GMMA 64x96x32 TN F32+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x80x32_F32E5M2E5M2_SS_TN +struct MMA_64x96x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[40]; + using CRegisters = float[48]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -47010,6 +17848,8 @@ struct MMA_64x80x32_F32E5M2E5M2_SS_TN float & d28, float & d29, float & d30, float & d31, float & d32, float & d33, float & d34, float & d35, float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -47017,16 +17857,17 @@ struct MMA_64x80x32_F32E5M2E5M2_SS_TN asm volatile( "{\n" ".reg 
.pred p;\n" - "setp.ne.b32 p, %42, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.f32.e5m2.e5m2 " + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " p, %43, %44;\n" + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -47037,31 +17878,31 @@ struct MMA_64x80x32_F32E5M2E5M2_SS_TN "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x80x32 TN F32+=E5M2*E5M2 +// GMMA 64x96x32 TN F32+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x80x32_F32E5M2E5M2_RS_TN +struct MMA_64x96x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[40]; + using CRegisters = float[48]; CUTE_HOST_DEVICE static void fma(uint32_t 
const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -47076,6 +17917,8 @@ struct MMA_64x80x32_F32E5M2E5M2_RS_TN float & d28, float & d29, float & d30, float & d31, float & d32, float & d33, float & d34, float & d35, float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -47083,16 +17926,17 @@ struct MMA_64x80x32_F32E5M2E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %45, 0;\n" - "wgmma.mma_async.sync.aligned.m64n80k32.f32.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " p, %46, %47;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -47103,30 +17947,31 @@ struct MMA_64x80x32_F32E5M2E5M2_RS_TN "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x96x32 TN F16+=E5M2*E5M2 +// GMMA 64x128x32 TN F16+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x96x32_F16E5M2E5M2_SS_TN +struct MMA_64x128x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -47137,6 +17982,8 @@ struct MMA_64x96x32_F16E5M2E5M2_SS_TN uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -47144,43 +17991,46 @@ struct MMA_64x96x32_F16E5M2E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %26, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e5m2 " + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " p, %27, %28;\n" + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), 
"+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x96x32 TN F16+=E5M2*E5M2 +// GMMA 64x128x32 TN F16+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x96x32_F16E5M2E5M2_RS_TN +struct MMA_64x128x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -47191,6 +18041,8 @@ struct MMA_64x96x32_F16E5M2E5M2_RS_TN uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -47198,43 +18050,46 @@ struct MMA_64x96x32_F16E5M2E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %29, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e5m2 " + "setp.ne.b32 p, %37, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " p, %30, %31;\n" + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x96x32 TN F32+=E5M2*E5M2 +// GMMA 64x128x32 TN F32+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x96x32_F32E5M2E5M2_SS_TN +struct MMA_64x128x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[48]; + using CRegisters = float[64]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -47251,6 +18106,10 @@ struct MMA_64x96x32_F32E5M2E5M2_SS_TN float & d36, float & d37, float & d38, float & d39, float & d40, float & d41, float & d42, float & d43, float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float 
& d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -47258,17 +18117,19 @@ struct MMA_64x96x32_F32E5M2E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %50, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " p, %51, %52;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -47281,29 +18142,33 @@ struct MMA_64x96x32_F32E5M2E5M2_SS_TN "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E5M2E5M2_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x96x32 TN F32+=E5M2*E5M2 +// GMMA 64x128x32 TN F32+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x96x32_F32E5M2E5M2_RS_TN +struct MMA_64x128x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[48]; + using CRegisters = float[64]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -47320,6 +18185,10 @@ struct MMA_64x96x32_F32E5M2E5M2_RS_TN float & d36, float & d37, float & d38, float & d39, float & d40, float & d41, float & d42, float & d43, float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -47327,17 +18196,19 @@ struct MMA_64x96x32_F32E5M2E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %53, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e5m2 " + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " p, %54, %55;\n" + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " 
%56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -47350,30 +18221,33 @@ struct MMA_64x96x32_F32E5M2E5M2_RS_TN "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x32 TN F16+=E5M2*E5M2 +// GMMA 64x192x32 TN F16+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x112x32_F16E5M2E5M2_SS_TN +struct MMA_64x192x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[28]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -47385,6 +18259,11 @@ struct MMA_64x112x32_F16E5M2E5M2_SS_TN uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & 
d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -47392,15 +18271,17 @@ struct MMA_64x112x32_F16E5M2E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %30, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k32.f16.e5m2.e5m2 " + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27}," - " %28," - " %29," - " p, %31, %32;\n" + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -47408,31 +18289,34 @@ struct MMA_64x112x32_F16E5M2E5M2_SS_TN "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif 
//////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x32 TN F16+=E5M2*E5M2 +// GMMA 64x192x32 TN F16+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x112x32_F16E5M2E5M2_RS_TN +struct MMA_64x192x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[28]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -47444,6 +18328,11 @@ struct MMA_64x112x32_F16E5M2E5M2_RS_TN uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -47451,15 +18340,17 @@ struct MMA_64x112x32_F16E5M2E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %33, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k32.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27}," - "{%28, %29, %30, %31}," - " %32," - " p, %34, %35;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, 
%27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -47467,31 +18358,34 @@ struct MMA_64x112x32_F16E5M2E5M2_RS_TN "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x32 TN F32+=E5M2*E5M2 +// GMMA 64x192x32 TN F32+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x112x32_F32E5M2E5M2_SS_TN +struct MMA_64x192x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[56]; + using CRegisters = float[96]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -47510,6 +18404,16 @@ struct MMA_64x112x32_F32E5M2E5M2_SS_TN float & d44, float & d45, float & d46, float & d47, float & d48, float & d49, float & d50, float & d51, float & d52, float & d53, float & d54, 
float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -47517,18 +18421,23 @@ struct MMA_64x112x32_F32E5M2E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %58, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k32.f32.e5m2.e5m2 " + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " p, %59, %60;\n" + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -47543,31 +18452,39 @@ struct MMA_64x112x32_F32E5M2E5M2_SS_TN "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), 
"+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x112x32 TN F32+=E5M2*E5M2 +// GMMA 64x192x32 TN F32+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x112x32_F32E5M2E5M2_RS_TN +struct MMA_64x192x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[56]; + using CRegisters = float[96]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -47586,6 +18503,16 @@ struct MMA_64x112x32_F32E5M2E5M2_RS_TN float & d44, float & d45, float & d46, float & d47, float & d48, float & d49, float & d50, float & d51, float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & 
d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -47593,18 +18520,23 @@ struct MMA_64x112x32_F32E5M2E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %61, 0;\n" - "wgmma.mma_async.sync.aligned.m64n112k32.f32.e5m2.e5m2 " + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " p, %62, %63;\n" + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -47619,30 +18551,39 @@ struct MMA_64x112x32_F32E5M2E5M2_RS_TN "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), 
"+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x128x32 TN F16+=E5M2*E5M2 +// GMMA 64x256x32 TN F16+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x128x32_F16E5M2E5M2_SS_TN +struct MMA_64x256x32_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -47655,6 +18596,14 @@ struct MMA_64x128x32_F16E5M2E5M2_SS_TN uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -47662,15 
+18611,19 @@ struct MMA_64x128x32_F16E5M2E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %34, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - " %32," - " %33," - " p, %35, %36;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -47679,29 +18632,37 @@ struct MMA_64x128x32_F16E5M2E5M2_SS_TN "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; 
//////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x128x32 TN F16+=E5M2*E5M2 +// GMMA 64x256x32 TN F16+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x128x32_F16E5M2E5M2_RS_TN +struct MMA_64x256x32_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -47714,6 +18675,14 @@ struct MMA_64x128x32_F16E5M2E5M2_RS_TN uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -47721,15 +18690,19 @@ struct MMA_64x128x32_F16E5M2E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %37, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - "{%32, %33, %34, %35}," - " %36," - " p, %38, %39;\n" + "setp.ne.b32 p, %69, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n256k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -47738,49 +18711,73 @@ struct MMA_64x128x32_F16E5M2E5M2_RS_TN "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x128x32 TN F32+=E5M2*E5M2 +// GMMA 64x256x32 TN F32+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x128x32_F32E5M2E5M2_SS_TN +struct MMA_64x256x32_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = 
uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[64]; + using CRegisters = float[128]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, 
float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -47788,8 +18785,8 @@ struct MMA_64x128x32_F32E5M2E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %66, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2 " + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -47797,69 +18794,109 @@ struct MMA_64x128x32_F32E5M2E5M2_SS_TN " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - " %64," - " %65," - " p, %67, %68;\n" + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, 
%108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), 
"+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x128x32 TN F32+=E5M2*E5M2 +// GMMA 64x256x32 TN F32+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x128x32_F32E5M2E5M2_RS_TN +struct MMA_64x256x32_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[64]; + using CRegisters = float[128]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - 
float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float 
& d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -47867,8 +18904,8 @@ struct MMA_64x128x32_F32E5M2E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %69, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2 " + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -47876,63 +18913,78 @@ struct MMA_64x128x32_F32E5M2E5M2_RS_TN " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - "{%64, %65, %66, %67}," - " %68," - " p, %70, %71;\n" + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), 
"+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), 
"+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x32 TN F16+=E5M2*E5M2 +// GMMA 64x8x32 TN F16+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x144x32_F16E5M2E5M2_SS_TN +struct MMA_64x8x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[36]; + using CRegisters = uint32_t[2]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d0, uint32_t & d1, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -47940,63 +18992,41 @@ struct MMA_64x144x32_F16E5M2E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %38, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k32.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, 
%6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35}," - " %36," - " %37," - " p, %39, %40;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e5m2 " + "{%0, %1}," + " %2," + " %3," + " p, %5, %6;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "+r"(d0), "+r"(d1) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x32 TN F16+=E5M2*E5M2 +// GMMA 64x8x32 TN F16+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x144x32_F16E5M2E5M2_RS_TN +struct MMA_64x8x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[36]; + using CRegisters = uint32_t[2]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & 
d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d0, uint32_t & d1, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -48004,72 +19034,41 @@ struct MMA_64x144x32_F16E5M2E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %41, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k32.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35}," - "{%36, %37, %38, %39}," - " %40," - " p, %42, %43;\n" + "setp.ne.b32 p, %7, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e5m2 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " p, %8, %9;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x32 TN F32+=E5M2*E5M2 +// GMMA 64x8x32 TN F32+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x144x32_F32E5M2E5M2_SS_TN +struct MMA_64x8x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[72]; + using CRegisters = float[4]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, + float & d0, float & d1, float & d2, float & d3, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -48077,85 +19076,41 @@ struct MMA_64x144x32_F32E5M2E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - 
"setp.ne.b32 p, %74, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k32.f32.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - " %72," - " %73," - " p, %75, %76;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif 
//////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x144x32 TN F32+=E5M2*E5M2 +// GMMA 64x8x32 TN F32+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x144x32_F32E5M2E5M2_RS_TN +struct MMA_64x8x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[72]; + using CRegisters = float[4]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, + float & d0, float & d1, float & d2, float & d3, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -48163,77 +19118,41 @@ struct MMA_64x144x32_F32E5M2E5M2_RS_TN asm volatile( "{\n" ".reg .pred 
p;\n" - "setp.ne.b32 p, %77, 0;\n" - "wgmma.mma_async.sync.aligned.m64n144k32.f32.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - "{%72, %73, %74, %75}," - " %76," - " p, %78, %79;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); 
#endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x32 TN F16+=E5M2*E5M2 +// GMMA 64x16x32 TN F16+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x160x32_F16E5M2E5M2_SS_TN +struct MMA_64x16x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -48241,65 +19160,41 @@ struct MMA_64x160x32_F16E5M2E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %42, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k32.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " p, %43, %44;\n" + "setp.ne.b32 p, %6, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n16k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x32 TN F16+=E5M2*E5M2 +// GMMA 64x16x32 TN F16+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x160x32_F16E5M2E5M2_RS_TN +struct MMA_64x16x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & 
d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -48307,75 +19202,42 @@ struct MMA_64x160x32_F16E5M2E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %45, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k32.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " p, %46, %47;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E5M2E5M2_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x32 TN F32+=E5M2*E5M2 +// GMMA 64x16x32 TN F32+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x160x32_F32E5M2E5M2_SS_TN +struct MMA_64x16x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[80]; + using CRegisters = float[8]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -48383,90 +19245,43 @@ struct MMA_64x160x32_F32E5M2E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %82, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k32.f32.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - " %80," - " %81," - " p, %83, %84;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) : "l"(desc_a), "l"(desc_b), 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x160x32 TN F32+=E5M2*E5M2 +// GMMA 64x16x32 TN F32+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x160x32_F32E5M2E5M2_RS_TN +struct MMA_64x16x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[80]; + using CRegisters = float[8]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float 
& d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -48474,81 +19289,43 @@ struct MMA_64x160x32_F32E5M2E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %85, 0;\n" - "wgmma.mma_async.sync.aligned.m64n160k32.f32.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - "{%80, %81, %82, %83}," - " %84," - " p, %86, %87;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), 
"+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x32 TN F16+=E5M2*E5M2 +// GMMA 64x32x32 TN F16+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x176x32_F16E5M2E5M2_SS_TN +struct MMA_64x32x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[44]; + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, 
uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -48556,68 +19333,43 @@ struct MMA_64x176x32_F16E5M2E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %46, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k32.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43}," - " %44," - " %45," - " p, %47, %48;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// 
-#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x32 TN F16+=E5M2*E5M2 +// GMMA 64x32x32 TN F16+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x176x32_F16E5M2E5M2_RS_TN +struct MMA_64x32x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[44]; + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -48625,53 +19377,37 @@ struct MMA_64x176x32_F16E5M2E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %49, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k32.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, 
" - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43}," - "{%44, %45, %46, %47}," - " %48," - " p, %50, %51;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x32 TN F32+=E5M2*E5M2 +// GMMA 64x32x32 TN F32+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x176x32_F32E5M2E5M2_SS_TN +struct MMA_64x32x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[88]; + using CRegisters = float[16]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -48680,24 +19416,6 @@ struct 
MMA_64x176x32_F32E5M2E5M2_SS_TN float & d04, float & d05, float & d06, float & d07, float & d08, float & d09, float & d10, float & d11, float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -48705,69 +19423,40 @@ struct MMA_64x176x32_F32E5M2E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %90, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k32.f32.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - " %88," - " %89," - " p, %91, %92;\n" + "setp.ne.b32 p, 
%18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x176x32 TN F32+=E5M2*E5M2 +// GMMA 64x32x32 TN F32+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x176x32_F32E5M2E5M2_RS_TN +struct 
MMA_64x32x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[88]; + using CRegisters = float[16]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -48776,24 +19465,6 @@ struct MMA_64x176x32_F32E5M2E5M2_RS_TN float & d04, float & d05, float & d06, float & d07, float & d08, float & d09, float & d10, float & d11, float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -48801,68 +19472,40 @@ struct MMA_64x176x32_F32E5M2E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %93, 0;\n" - "wgmma.mma_async.sync.aligned.m64n176k32.f32.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, 
" - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - "{%88, %89, %90, %91}," - " %92," - " p, %94, %95;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; 
-#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x192x32 TN F16+=E5M2*E5M2 +// GMMA 64x64x32 TN F16+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x192x32_F16E5M2E5M2_SS_TN +struct MMA_64x64x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -48871,14 +19514,6 @@ struct MMA_64x192x32_F16E5M2E5M2_SS_TN uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -48886,52 +19521,40 @@ struct MMA_64x192x32_F16E5M2E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %50, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k32.f16.e5m2.e5m2 " + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - 
" %48," - " %49," - " p, %51, %52;\n" + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x192x32 TN F16+=E5M2*E5M2 +// GMMA 64x64x32 TN F16+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x192x32_F16E5M2E5M2_RS_TN +struct MMA_64x64x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -48940,14 +19563,6 @@ struct MMA_64x192x32_F16E5M2E5M2_RS_TN uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & 
d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -48955,80 +19570,52 @@ struct MMA_64x192x32_F16E5M2E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %53, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k32.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " p, %54, %55;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x192x32 TN F32+=E5M2*E5M2 +// GMMA 64x64x32 TN F32+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x192x32_F32E5M2E5M2_SS_TN +struct MMA_64x64x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[96]; + using CRegisters = float[32]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float 
& d82, float & d83, - float & d84, float & d85, float & d86, float & d87, - float & d88, float & d89, float & d90, float & d91, - float & d92, float & d93, float & d94, float & d95, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -49036,23 +19623,15 @@ struct MMA_64x192x32_F32E5M2E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %98, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - " %96," - " %97," - " p, %99, %100;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -49061,45 +19640,29 @@ struct MMA_64x192x32_F32E5M2E5M2_SS_TN "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), 
"+f"(d21), "+f"(d22), "+f"(d23), "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), - "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), - "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA 64x192x32 TN F32+=E5M2*E5M2 +// GMMA 64x64x32 TN F32+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x192x32_F32E5M2E5M2_RS_TN +struct MMA_64x64x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[96]; + using CRegisters = float[32]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -49112,22 +19675,6 @@ struct MMA_64x192x32_F32E5M2E5M2_RS_TN float & d20, float & d21, float & d22, float & d23, float & 
d24, float & d25, float & d26, float & d27, float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, - float & d88, float & d89, float & d90, float & d91, - float & d92, float & d93, float & d94, float & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -49135,23 +19682,15 @@ struct MMA_64x192x32_F32E5M2E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %101, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - "{%96, %97, %98, %99}," - " %100," - " p, %102, %103;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, 
%13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -49160,46 +19699,29 @@ struct MMA_64x192x32_F32E5M2E5M2_RS_TN "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), - "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), - "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x32 TN F16+=E5M2*E5M2 +// GMMA 64x96x32 TN F16+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct 
MMA_64x208x32_F16E5M2E5M2_SS_TN +struct MMA_64x96x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[52]; + using CRegisters = uint32_t[24]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -49210,13 +19732,6 @@ struct MMA_64x208x32_F16E5M2E5M2_SS_TN uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -49224,56 +19739,43 @@ struct MMA_64x208x32_F16E5M2E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %54, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k32.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51}," - " %52," - " %53," - " p, %55, %56;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), 
"+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x32 TN F16+=E5M2*E5M2 +// GMMA 64x96x32 TN F16+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x208x32_F16E5M2E5M2_RS_TN +struct MMA_64x96x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[52]; + using CRegisters = uint32_t[24]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -49284,13 +19786,6 @@ struct MMA_64x208x32_F16E5M2E5M2_RS_TN uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, 
uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -49298,86 +19793,59 @@ struct MMA_64x208x32_F16E5M2E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %57, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k32.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51}," - "{%52, %53, %54, %55}," - " %56," - " p, %58, %59;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to 
use MMA_64x208x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x32 TN F32+=E5M2*E5M2 +// GMMA 64x96x32 TN F32+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x208x32_F32E5M2E5M2_SS_TN +struct MMA_64x96x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[104]; + using CRegisters = float[48]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, 
float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -49385,105 +19853,68 @@ struct MMA_64x208x32_F32E5M2E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %106, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k32.f32.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - " %104," - " %105," - " p, %107, %108;\n" + "setp.ne.b32 p, %50, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + 
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x208x32 TN F32+=E5M2*E5M2 +// GMMA 64x96x32 TN F32+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x208x32_F32E5M2E5M2_RS_TN +struct MMA_64x96x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[104]; + using CRegisters = float[48]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & 
d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -49491,75 +19922,52 @@ struct MMA_64x208x32_F32E5M2E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %109, 0;\n" - "wgmma.mma_async.sync.aligned.m64n208k32.f32.e5m2.e5m2 " + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, 
%17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - "{%104, %105, %106, %107}," - " %108," - " p, %110, %111;\n" + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) - : "r"(a000), "r"(a001), 
"r"(a002), "r"(a003), + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x32 TN F16+=E5M2*E5M2 +// GMMA 64x128x32 TN F16+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x224x32_F16E5M2E5M2_SS_TN +struct MMA_64x128x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -49572,12 +19980,6 @@ struct MMA_64x224x32_F16E5M2E5M2_SS_TN uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t 
& d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -49585,18 +19987,15 @@ struct MMA_64x224x32_F16E5M2E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %58, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k32.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " p, %59, %60;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -49605,37 +20004,29 @@ struct MMA_64x224x32_F16E5M2E5M2_SS_TN "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x224x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x32 TN F16+=E5M2*E5M2 +// GMMA 64x128x32 TN F16+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x224x32_F16E5M2E5M2_RS_TN +struct MMA_64x128x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -49648,12 +20039,6 @@ struct MMA_64x224x32_F16E5M2E5M2_RS_TN uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -49661,18 +20046,15 @@ struct MMA_64x224x32_F16E5M2E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %61, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k32.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, 
%31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " p, %62, %63;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -49681,69 +20063,49 @@ struct MMA_64x224x32_F16E5M2E5M2_RS_TN "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x32 TN F32+=E5M2*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -struct MMA_64x224x32_F32E5M2E5M2_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[112]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t 
const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN 
F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -49751,8 +20113,8 @@ struct MMA_64x224x32_F32E5M2E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %114, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k32.f32.e5m2.e5m2 " + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -49760,101 +20122,69 @@ struct MMA_64x224x32_F32E5M2E5M2_SS_TN " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, 
%67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - " %112," - " %113," - " p, %115, %116;\n" + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + 
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x224x32 TN F32+=E5M2*E5M2 +// GMMA 64x128x32 TN F32+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x224x32_F32E5M2E5M2_RS_TN +struct MMA_64x128x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[112]; + using CRegisters = float[64]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - 
float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, 
float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -49862,8 +20192,8 @@ struct MMA_64x224x32_F32E5M2E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %117, 0;\n" - "wgmma.mma_async.sync.aligned.m64n224k32.f32.e5m2.e5m2 " + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -49871,69 +20201,49 @@ struct MMA_64x224x32_F32E5M2E5M2_RS_TN " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - "{%112, %113, %114, %115}," - " %116," - " p, %118, %119;\n" + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), 
- "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif 
//////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x32 TN F16+=E5M2*E5M2 +// GMMA 64x192x32 TN F16+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x240x32_F16E5M2E5M2_SS_TN +struct MMA_64x192x32_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[60]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -49950,9 +20260,6 @@ struct MMA_64x240x32_F16E5M2E5M2_SS_TN uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -49960,19 +20267,17 @@ struct MMA_64x240x32_F16E5M2E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %62, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k32.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59}," - " %60," - " %61," - " p, %63, %64;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, 
" + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -49985,34 +20290,29 @@ struct MMA_64x240x32_F16E5M2E5M2_SS_TN "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x32 TN F16+=E5M2*E5M2 +// GMMA 64x192x32 TN F16+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x240x32_F16E5M2E5M2_RS_TN +struct MMA_64x192x32_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[60]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -50029,9 +20329,6 @@ struct MMA_64x240x32_F16E5M2E5M2_RS_TN uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & 
d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -50039,19 +20336,17 @@ struct MMA_64x240x32_F16E5M2E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %65, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k32.f16.e5m2.e5m2 " + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e5m2.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59}," - "{%60, %61, %62, %63}," - " %64," - " p, %66, %67;\n" + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -50064,68 +20359,57 @@ struct MMA_64x240x32_F16E5M2E5M2_RS_TN "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif 
//////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x32 TN F32+=E5M2*E5M2 +// GMMA 64x192x32 TN F32+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -struct MMA_64x240x32_F32E5M2E5M2_SS_TN +struct MMA_64x192x32_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[120]; + using CRegisters = float[96]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & 
d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -50133,8 +20417,8 @@ struct MMA_64x240x32_F32E5M2E5M2_SS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %122, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k32.f32.e5m2.e5m2 " + "setp.ne.b32 p, %98, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -50146,102 +20430,85 @@ struct MMA_64x240x32_F32E5M2E5M2_SS_TN " %64, %65, %66, %67, %68, %69, %70, %71, " " %72, %73, %74, %75, %76, %77, %78, %79, " " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - " %120," - " %121," - " p, %123, %124;\n" + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), 
"+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// GMMA 64x240x32 TN F32+=E5M2*E5M2 +// GMMA 64x192x32 TN F32+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One > -struct MMA_64x240x32_F32E5M2E5M2_RS_TN +struct MMA_64x192x32_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using BRegisters = uint64_t[1]; - using CRegisters = float[120]; + using CRegisters = float[96]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & 
d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) @@ -50249,8 +20516,8 @@ struct MMA_64x240x32_F32E5M2E5M2_RS_TN asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %125, 0;\n" - "wgmma.mma_async.sync.aligned.m64n240k32.f32.e5m2.e5m2 " + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, 
%18, %19, %20, %21, %22, %23, " @@ -50262,53 +20529,43 @@ struct MMA_64x240x32_F32E5M2E5M2_RS_TN " %64, %65, %66, %67, %68, %69, %70, %71, " " %72, %73, %74, %75, %76, %77, %78, %79, " " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - "{%120, %121, %122, %123}," - " %124," - " p, %126, %127;\n" + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), 
"+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -50708,7 +20965,10 @@ struct MMA_64x256x32_F32E5M2E5M2_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// - } // namespace SM90::GMMA } // namespace cute + +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +#include "mma_sm90_gmma_ext.hpp" +#endif diff --git a/include/cute/arch/mma_sm90_gmma_ext.hpp b/include/cute/arch/mma_sm90_gmma_ext.hpp new file mode 100644 index 000000000..10a36aff8 --- /dev/null +++ b/include/cute/arch/mma_sm90_gmma_ext.hpp @@ -0,0 +1,56445 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include // CUTE_HOST_DEVICE + +#include "cutlass/arch/synclog.hpp" + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) +# define CUTE_ARCH_MMA_SM90A_ENABLED +#endif + +namespace cute { + +namespace SM90::GMMA { + +// GMMA 64x24x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " p, %9, %10, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " p, %12, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One +> +struct MMA_64x40x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " p, %13, %14, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & 
d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " p, %16, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, 
%5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), 
"r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " p, %17, %18, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " p, %20, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major 
tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " p, %21, %22, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters 
= uint64_t[1]; + using CRegisters = uint32_t[18]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " p, %24, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = 
uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + 
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & 
d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " p, %25, %26, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, 
uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " p, %28, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t 
& d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " p, %29, %30, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + static_assert(tnspA == GMMA::Major::K, + "Register 
source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " p, %32, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One 
+> +struct MMA_64x112x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x16 F16+=F16*F16 +template < + 
GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), 
"n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " p, %33, %34, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + 
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, 
%29}," + "{%30, %31, %32, %33}," + " %34," + " p, %36, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm 
volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " p, %37, %38, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, 
uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " p, %40, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + 
+ CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + 
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n152k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " p, %41, %42, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, 
uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " p, %44, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x16_F16F16F16_SS +{ 
+ using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, 
%26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, 
uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " p, %45, %46, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + static_assert(tnspA == GMMA::Major::K, + 
"Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " p, %48, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), 
"r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, 
%29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, 
uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + " %46," + " %47," + " p, %49, %50, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), 
"+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut 
const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " p, %52, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, 
uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " p, %53, %54, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + 
"+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm 
volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " p, %56, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, 
uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : 
"l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); 
+ asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & 
d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " p, %57, %58, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), 
"+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const 
scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " p, %60, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = 
uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), 
"+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + 
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x16 
F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, 
%57}," + " %58," + " %59," + " p, %61, %62, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, 
uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " p, %64, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), 
"+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using 
CRegisters = uint32_t[60]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), 
"+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & 
d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " p, %65, %66, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " p, %68, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x16_F32F16F16_SS +{ 
+ using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16, %17, %18;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, 
float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; /* 20 fp32 accumulators d00..d19 for the 64x40x16 tile */ + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" +
"wgmma.mma_async.sync.aligned.m64n40k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; /* RS variant: A from 4 uint32_t registers, B via desc_b */ + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k16.f32.f16.f16 " + "{%0,
%1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; /* 24 fp32 accumulators d00..d23 for the 64x48x16 tile */ + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + "
p, %27, %28, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; /* mirrors the SS variant above but with register-resident A (a00..a03) */ + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17,
%18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; /* 28 fp32 accumulators d00..d27 for the 64x56x16 tile */ + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22,
%23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; /* RS variant: static_assert below requires K-major A */ + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33,
0;\n" + "wgmma.mma_async.sync.aligned.m64n56k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; /* 36 fp32 accumulators d00..d35 for the 64x72x16 tile */ + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; /* RS variant of 64x72x16; A held in 4 uint32_t registers */ + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, +
float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); /* p = (scale_D != 0), passed as the scale-d predicate to wgmma below */ + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x16_F32F16F16_SS +{ + using DRegisters = void; + using
ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; /* 40 fp32 accumulators d00..d39 for the 64x80x16 tile */ + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F32F16F16_SS without
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; /* RS variant of 64x80x16 */ + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10),
"+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; /* 44 fp32 accumulators d00..d43 for the 64x88x16 tile */ + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" +
"wgmma.mma_async.sync.aligned.m64n88k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; /* RS variant of 64x88x16; requires K-major A (static_assert below) */ + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float &
d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); /* p = (scale_D != 0) is the wgmma scale-d predicate */ + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x16 F32+=F16*F16 +template < +
GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; /* 52 fp32 accumulators d00..d51 for the 64x104x16 tile */ + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21),
"+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; /* RS variant of 64x104x16 */ + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49,
float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using 
CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), 
"+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + 
CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), 
"+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, 
float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %70, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " p, %71, %72, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), 
"+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + 
float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %73, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " p, %74, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + 
"+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const 
scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p, %75, %76, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x16 F32+=F16*F16 +template 
< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, 
%33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p, %78, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + 
float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %78, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " p, %79, %80, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + 
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + 
float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %81, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " p, %82, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + 
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float 
& d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p, %83, %84, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, 
float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p, %86, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F32F16F16_RS without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %86, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n168k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " p, %87, %88, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x16 F32+=F16*F16 +template < + GMMA::Major 
tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %89, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, 
" + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " p, %90, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, 
+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, 
%38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p, %91, %92, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x16_F32F16F16_RS +{ + using DRegisters = void; + using 
ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + 
" %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p, %94, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct 
MMA_64x184x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %94, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, 
%38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " p, %95, %96, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> 
+struct MMA_64x184x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %97, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, 
%9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " p, %98, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
// GMMA 64x200x16 F32+=F16*F16
// SS variant: both A and B operands are passed as 64-bit matrix descriptors
// (desc_a/desc_b); D is a 100-register f32 accumulator fragment updated in
// place. scale_D is materialized as predicate `p` (via setp.ne.b32) for the
// instruction's scale-d operand; scaleA/scaleB/tnspA/tnspB are compile-time
// immediates ("n" constraints).
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x200x16_F32F16F16_SS
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[100];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d000, float & d001, float & d002, float & d003,
      float & d004, float & d005, float & d006, float & d007,
      float & d008, float & d009, float & d010, float & d011,
      float & d012, float & d013, float & d014, float & d015,
      float & d016, float & d017, float & d018, float & d019,
      float & d020, float & d021, float & d022, float & d023,
      float & d024, float & d025, float & d026, float & d027,
      float & d028, float & d029, float & d030, float & d031,
      float & d032, float & d033, float & d034, float & d035,
      float & d036, float & d037, float & d038, float & d039,
      float & d040, float & d041, float & d042, float & d043,
      float & d044, float & d045, float & d046, float & d047,
      float & d048, float & d049, float & d050, float & d051,
      float & d052, float & d053, float & d054, float & d055,
      float & d056, float & d057, float & d058, float & d059,
      float & d060, float & d061, float & d062, float & d063,
      float & d064, float & d065, float & d066, float & d067,
      float & d068, float & d069, float & d070, float & d071,
      float & d072, float & d073, float & d074, float & d075,
      float & d076, float & d077, float & d078, float & d079,
      float & d080, float & d081, float & d082, float & d083,
      float & d084, float & d085, float & d086, float & d087,
      float & d088, float & d089, float & d090, float & d091,
      float & d092, float & d093, float & d094, float & d095,
      float & d096, float & d097, float & d098, float & d099,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    // Operand map: %0-%99 accumulators, %100 desc_a, %101 desc_b,
    // %102 scale_D (runtime), %103-%106 scaleA/scaleB/tnspA/tnspB (immediates).
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %102, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n200k16.f32.f16.f16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99},"
      " %100,"
      " %101,"
      " p, %103, %104, %105, %106;\n"
    "}\n"
      : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
        "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
        "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
        "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
        "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
        "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
        "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
        "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
        "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
        "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
        "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
        "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
        "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
        "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
        "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
        "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
        "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
        "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
        "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
        "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
        "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
        "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
        "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
        "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
        "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x200x16 F32+=F16*F16
// RS variant: A is sourced from four 32-bit registers (a000-a003) and must be
// K-major (enforced by the static_assert); B is passed as a 64-bit matrix
// descriptor. Accumulator layout and scale/transpose handling match the SS
// variant above, except tnspA is fixed and therefore not an asm operand.
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x200x16_F32F16F16_RS
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using BRegisters = uint64_t[1];
  using CRegisters = float[100];

  static_assert(tnspA == GMMA::Major::K,
      "Register source operand A must have K major layout.");

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
      uint64_t const& desc_b,
      float & d000, float & d001, float & d002, float & d003,
      float & d004, float & d005, float & d006, float & d007,
      float & d008, float & d009, float & d010, float & d011,
      float & d012, float & d013, float & d014, float & d015,
      float & d016, float & d017, float & d018, float & d019,
      float & d020, float & d021, float & d022, float & d023,
      float & d024, float & d025, float & d026, float & d027,
      float & d028, float & d029, float & d030, float & d031,
      float & d032, float & d033, float & d034, float & d035,
      float & d036, float & d037, float & d038, float & d039,
      float & d040, float & d041, float & d042, float & d043,
      float & d044, float & d045, float & d046, float & d047,
      float & d048, float & d049, float & d050, float & d051,
      float & d052, float & d053, float & d054, float & d055,
      float & d056, float & d057, float & d058, float & d059,
      float & d060, float & d061, float & d062, float & d063,
      float & d064, float & d065, float & d066, float & d067,
      float & d068, float & d069, float & d070, float & d071,
      float & d072, float & d073, float & d074, float & d075,
      float & d076, float & d077, float & d078, float & d079,
      float & d080, float & d081, float & d082, float & d083,
      float & d084, float & d085, float & d086, float & d087,
      float & d088, float & d089, float & d090, float & d091,
      float & d092, float & d093, float & d094, float & d095,
      float & d096, float & d097, float & d098, float & d099,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    // Operand map: %0-%99 accumulators, %100-%103 A registers, %104 desc_b,
    // %105 scale_D (runtime), %106-%108 scaleA/scaleB/tnspB (immediates).
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %105, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n200k16.f32.f16.f16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99},"
      "{%100, %101, %102, %103},"
      " %104,"
      " p, %106, %107, %108;\n"
    "}\n"
      : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
        "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
        "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
        "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
        "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
        "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
        "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
        "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
        "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
        "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
        "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
        "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
        "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
        "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
        "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
        "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
        "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
        "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
        "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
        "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
        "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
        "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
        "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
        "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
        "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099)
      : "r"(a000), "r"(a001), "r"(a002), "r"(a003),
        "l"(desc_b),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x208x16 F32+=F16*F16
// Same structure as the 64x200 SS variant, with N=208 (104 f32 accumulators).
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x208x16_F32F16F16_SS
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[104];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d000, float & d001, float & d002, float & d003,
      float & d004, float & d005, float & d006, float & d007,
      float & d008, float & d009, float & d010, float & d011,
      float & d012, float & d013, float & d014, float & d015,
      float & d016, float & d017, float & d018, float & d019,
      float & d020, float & d021, float & d022, float & d023,
      float & d024, float & d025, float & d026, float & d027,
      float & d028, float & d029, float & d030, float & d031,
      float & d032, float & d033, float & d034, float & d035,
      float & d036, float & d037, float & d038, float & d039,
      float & d040, float & d041, float & d042, float & d043,
      float & d044, float & d045, float & d046, float & d047,
      float & d048, float & d049, float & d050, float & d051,
      float & d052, float & d053, float & d054, float & d055,
      float & d056, float & d057, float & d058, float & d059,
      float & d060, float & d061, float & d062, float & d063,
      float & d064, float & d065, float & d066, float & d067,
      float & d068, float & d069, float & d070, float & d071,
      float & d072, float & d073, float & d074, float & d075,
      float & d076, float & d077, float & d078, float & d079,
      float & d080, float & d081, float & d082, float & d083,
      float & d084, float & d085, float & d086, float & d087,
      float & d088, float & d089, float & d090, float & d091,
      float & d092, float & d093, float & d094, float & d095,
      float & d096, float & d097, float & d098, float & d099,
      float & d100, float & d101, float & d102, float & d103,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    // Operand map: %0-%103 accumulators, %104 desc_a, %105 desc_b,
    // %106 scale_D, %107-%110 scaleA/scaleB/tnspA/tnspB.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %106, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n208k16.f32.f16.f16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103},"
      " %104,"
      " %105,"
      " p, %107, %108, %109, %110;\n"
    "}\n"
      : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
        "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
        "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
        "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
        "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
        "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
        "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
        "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
        "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
        "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
        "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
        "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
        "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
        "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
        "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
        "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
        "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
        "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
        "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
        "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
        "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
        "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
        "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
        "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
        "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
        "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x208x16 F32+=F16*F16
// Same structure as the 64x200 RS variant, with N=208 (104 f32 accumulators).
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x208x16_F32F16F16_RS
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using BRegisters = uint64_t[1];
  using CRegisters = float[104];

  static_assert(tnspA == GMMA::Major::K,
      "Register source operand A must have K major layout.");

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
      uint64_t const& desc_b,
      float & d000, float & d001, float & d002, float & d003,
      float & d004, float & d005, float & d006, float & d007,
      float & d008, float & d009, float & d010, float & d011,
      float & d012, float & d013, float & d014, float & d015,
      float & d016, float & d017, float & d018, float & d019,
      float & d020, float & d021, float & d022, float & d023,
      float & d024, float & d025, float & d026, float & d027,
      float & d028, float & d029, float & d030, float & d031,
      float & d032, float & d033, float & d034, float & d035,
      float & d036, float & d037, float & d038, float & d039,
      float & d040, float & d041, float & d042, float & d043,
      float & d044, float & d045, float & d046, float & d047,
      float & d048, float & d049, float & d050, float & d051,
      float & d052, float & d053, float & d054, float & d055,
      float & d056, float & d057, float & d058, float & d059,
      float & d060, float & d061, float & d062, float & d063,
      float & d064, float & d065, float & d066, float & d067,
      float & d068, float & d069, float & d070, float & d071,
      float & d072, float & d073, float & d074, float & d075,
      float & d076, float & d077, float & d078, float & d079,
      float & d080, float & d081, float & d082, float & d083,
      float & d084, float & d085, float & d086, float & d087,
      float & d088, float & d089, float & d090, float & d091,
      float & d092, float & d093, float & d094, float & d095,
      float & d096, float & d097, float & d098, float & d099,
      float & d100, float & d101, float & d102, float & d103,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    // Operand map: %0-%103 accumulators, %104-%107 A registers, %108 desc_b,
    // %109 scale_D, %110-%112 scaleA/scaleB/tnspB.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %109, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n208k16.f32.f16.f16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103},"
      "{%104, %105, %106, %107},"
      " %108,"
      " p, %110, %111, %112;\n"
    "}\n"
      : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
        "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
        "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
        "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
        "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
        "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
        "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
        "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
        "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
        "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
        "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
        "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
        "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
        "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
        "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
        "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
        "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
        "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
        "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
        "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
        "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
        "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
        "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
        "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
        "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
        "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103)
      : "r"(a000), "r"(a001), "r"(a002), "r"(a003),
        "l"(desc_b),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x216x16 F32+=F16*F16
// Same structure as the 64x200 SS variant, with N=216 (108 f32 accumulators).
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x216x16_F32F16F16_SS
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[108];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d000, float & d001, float & d002, float & d003,
      float & d004, float & d005, float & d006, float & d007,
      float & d008, float & d009, float & d010, float & d011,
      float & d012, float & d013, float & d014, float & d015,
      float & d016, float & d017, float & d018, float & d019,
      float & d020, float & d021, float & d022, float & d023,
      float & d024, float & d025, float & d026, float & d027,
      float & d028, float & d029, float & d030, float & d031,
      float & d032, float & d033, float & d034, float & d035,
      float & d036, float & d037, float & d038, float & d039,
      float & d040, float & d041, float & d042, float & d043,
      float & d044, float & d045, float & d046, float & d047,
      float & d048, float & d049, float & d050, float & d051,
      float & d052, float & d053, float & d054, float & d055,
      float & d056, float & d057, float & d058, float & d059,
      float & d060, float & d061, float & d062, float & d063,
      float & d064, float & d065, float & d066, float & d067,
      float & d068, float & d069, float & d070, float & d071,
      float & d072, float & d073, float & d074, float & d075,
      float & d076, float & d077, float & d078, float & d079,
      float & d080, float & d081, float & d082, float & d083,
      float & d084, float & d085, float & d086, float & d087,
      float & d088, float & d089, float & d090, float & d091,
      float & d092, float & d093, float & d094, float & d095,
      float & d096, float & d097, float & d098, float & d099,
      float & d100, float & d101, float & d102, float & d103,
      float & d104, float & d105, float & d106, float & d107,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    // Operand map: %0-%107 accumulators, %108 desc_a, %109 desc_b,
    // %110 scale_D, %111-%114 scaleA/scaleB/tnspA/tnspB.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %110, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n216k16.f32.f16.f16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103, "
      " %104, %105, %106, %107},"
      " %108,"
      " %109,"
      " p, %111, %112, %113, %114;\n"
    "}\n"
      : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
        "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
        "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
        "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
        "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
        "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
        "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
        "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
        "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
        "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
        "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
        "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
        "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
        "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
        "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
        "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
        "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
        "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
        "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
        "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
        "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
        "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
        "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
        "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
        "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
        "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
        "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x216x16 F32+=F16*F16
// Same structure as the 64x200 RS variant, with N=216 (108 f32 accumulators).
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x216x16_F32F16F16_RS
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using BRegisters = uint64_t[1];
  using CRegisters = float[108];

  static_assert(tnspA == GMMA::Major::K,
      "Register source operand A must have K major layout.");

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
      uint64_t const& desc_b,
      float & d000, float & d001, float & d002, float & d003,
      float & d004, float & d005, float & d006, float & d007,
      float & d008, float & d009, float & d010, float & d011,
      float & d012, float & d013, float & d014, float & d015,
      float & d016, float & d017, float & d018, float & d019,
      float & d020, float & d021, float & d022, float & d023,
      float & d024, float & d025, float & d026, float & d027,
      float & d028, float & d029, float & d030, float & d031,
      float & d032, float & d033, float & d034, float & d035,
      float & d036, float & d037, float & d038, float & d039,
      float & d040, float & d041, float & d042, float & d043,
      float & d044, float & d045, float & d046, float & d047,
      float & d048, float & d049, float & d050, float & d051,
      float & d052, float & d053, float & d054, float & d055,
      float & d056, float & d057, float & d058, float & d059,
      float & d060, float & d061, float & d062, float & d063,
      float & d064, float & d065, float & d066, float & d067,
      float & d068, float & d069, float & d070, float & d071,
      float & d072, float & d073, float & d074, float & d075,
      float & d076, float & d077, float & d078, float & d079,
      float & d080, float & d081, float & d082, float & d083,
      float & d084, float & d085, float & d086, float & d087,
      float & d088, float & d089, float & d090, float & d091,
      float & d092, float & d093, float & d094, float & d095,
      float & d096, float & d097, float & d098, float & d099,
      float & d100, float & d101, float & d102, float & d103,
      float & d104, float & d105, float & d106, float & d107,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    // Operand map: %0-%107 accumulators, %108-%111 A registers, %112 desc_b,
    // %113 scale_D, %114-%116 scaleA/scaleB/tnspB.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %113, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n216k16.f32.f16.f16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103, "
      " %104, %105, %106, %107},"
      "{%108, %109, %110, %111},"
      " %112,"
      " p, %114, %115, %116;\n"
    "}\n"
      : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
        "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
        "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
        "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
        "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
        "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
        "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
        "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
        "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
        "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
        "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
        "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
        "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
        "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
        "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
        "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
        "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
        "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
        "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
        "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
        "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
        "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
        "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
        "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
        "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
        "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
        "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107)
      : "r"(a000), "r"(a001), "r"(a002), "r"(a003),
        "l"(desc_b),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x224x16 F32+=F16*F16
// Same structure as the 64x200 SS variant, with N=224 (112 f32 accumulators).
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x224x16_F32F16F16_SS
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[112];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d000, float & d001, float & d002, float & d003,
      float & d004, float & d005, float & d006, float & d007,
      float & d008, float & d009, float & d010, float & d011,
      float & d012, float & d013, float & d014, float & d015,
      float & d016, float & d017, float & d018, float & d019,
      float & d020, float & d021, float & d022, float & d023,
      float & d024, float & d025, float & d026, float & d027,
      float & d028, float & d029, float & d030, float & d031,
      float & d032, float & d033, float & d034, float & d035,
      float & d036, float & d037, float & d038, float & d039,
      float & d040, float & d041, float & d042, float & d043,
      float & d044, float & d045, float & d046, float & d047,
      float & d048, float & d049, float & d050, float & d051,
      float & d052, float & d053, float & d054, float & d055,
      float & d056, float & d057, float & d058, float & d059,
      float & d060, float & d061, float & d062, float & d063,
      float & d064, float & d065, float & d066, float & d067,
      float & d068, float & d069, float & d070, float & d071,
      float & d072, float & d073, float & d074, float & d075,
      float & d076, float & d077, float & d078, float & d079,
      float & d080, float & d081, float & d082, float & d083,
      float & d084, float & d085, float & d086, float & d087,
      float & d088, float & d089, float & d090, float & d091,
      float & d092, float & d093, float & d094, float & d095,
      float & d096, float & d097, float & d098, float & d099,
      float & d100, float & d101, float & d102, float & d103,
      float & d104, float & d105, float & d106, float & d107,
      float & d108, float & d109, float & d110, float & d111,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    // Operand map: %0-%111 accumulators, %112 desc_a, %113 desc_b,
    // %114 scale_D, %115-%118 scaleA/scaleB/tnspA/tnspB.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %114, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n224k16.f32.f16.f16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103, "
      " %104, %105, %106, %107, %108, %109, %110, %111},"
      " %112,"
      " %113,"
      " p, %115, %116, %117, %118;\n"
    "}\n"
      : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
        "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
        "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
        "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
        "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
        "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
        "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
        "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
        "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
        "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
        "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
        "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
        "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
        "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
        "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
        "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
        "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
        "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
        "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
        "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
        "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
        "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
        "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
        "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
        "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
        "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
        "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107),
        "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x224x16 F32+=F16*F16
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x224x16_F32F16F16_RS
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using BRegisters = uint64_t[1];
  using CRegisters = float[112];

  static_assert(tnspA == GMMA::Major::K,
      "Register source operand A must have K major layout.");

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
      uint64_t const& desc_b,
      float & d000, float & d001, float & d002, float & d003,
      float & d004, float & d005, float & d006, float & d007,
      float & d008, float & d009, float & d010, float & d011,
      float & d012, float & d013, float & d014, float & d015,
      float & d016, float & d017, float & d018, float & d019,
      float & d020, float & d021, float & d022, float & d023,
      float & d024, float & d025, 
float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, 
%81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p, %118, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x232x16 F32+=F16*F16
// Warpgroup-wide asynchronous MMA wrapper around
// PTX "wgmma.mma_async.sync.aligned.m64n232k16.f32.f16.f16".
// "SS" variant: both A and B are supplied as 64-bit shared-memory matrix
// descriptors.  tnspA/tnspB (operand major mode) and scaleA/scaleB (input
// scale) are compile-time template values baked in as immediate ("n") asm
// operands; scale_D is a runtime argument turned into the wgmma scale-D
// predicate.
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x232x16_F32F16F16_SS
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];  // shared-memory matrix descriptor for A
  using BRegisters = uint64_t[1];  // shared-memory matrix descriptor for B
  using CRegisters = float[116];   // f32 accumulator fragment (116 = 232/2 values)

  // Issues one wgmma instruction.  d000..d115 are read-modify-write
  // accumulator registers ("+f" constraints below).
  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d000, float & d001, float & d002, float & d003,
      float & d004, float & d005, float & d006, float & d007,
      float & d008, float & d009, float & d010, float & d011,
      float & d012, float & d013, float & d014, float & d015,
      float & d016, float & d017, float & d018, float & d019,
      float & d020, float & d021, float & d022, float & d023,
      float & d024, float & d025, float & d026, float & d027,
      float & d028, float & d029, float & d030, float & d031,
      float & d032, float & d033, float & d034, float & d035,
      float & d036, float & d037, float & d038, float & d039,
      float & d040, float & d041, float & d042, float & d043,
      float & d044, float & d045, float & d046, float & d047,
      float & d048, float & d049, float & d050, float & d051,
      float & d052, float & d053, float & d054, float & d055,
      float & d056, float & d057, float & d058, float & d059,
      float & d060, float & d061, float & d062, float & d063,
      float & d064, float & d065, float & d066, float & d067,
      float & d068, float & d069, float & d070, float & d071,
      float & d072, float & d073, float & d074, float & d075,
      float & d076, float & d077, float & d078, float & d079,
      float & d080, float & d081, float & d082, float & d083,
      float & d084, float & d085, float & d086, float & d087,
      float & d088, float & d089, float & d090, float & d091,
      float & d092, float & d093, float & d094, float & d095,
      float & d096, float & d097, float & d098, float & d099,
      float & d100, float & d101, float & d102, float & d103,
      float & d104, float & d105, float & d106, float & d107,
      float & d108, float & d109, float & d110, float & d111,
      float & d112, float & d113, float & d114, float & d115,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    // Operand numbering: %0-%115 accumulators, %116 desc_a, %117 desc_b,
    // %118 scale_D (tested into predicate p), %119-%122 immediates
    // scaleA/scaleB/tnspA/tnspB.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %118, 0;\n"  // p = (scale_D != 0), forwarded as the wgmma scale-D operand
      "wgmma.mma_async.sync.aligned.m64n232k16.f32.f16.f16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103, "
      " %104, %105, %106, %107, %108, %109, %110, %111, "
      " %112, %113, %114, %115},"
      " %116,"
      " %117,"
      " p, %119, %120, %121, %122;\n"
    "}\n"
      : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
        "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
        "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
        "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
        "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
        "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
        "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
        "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
        "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
        "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
        "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
        "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
        "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
        "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
        "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
        "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
        "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
        "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
        "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
        "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
        "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
        "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
        "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
        "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
        "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
        "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
        "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107),
        "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111),
        "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x232x16 F32+=F16*F16
// Same 64x232x16 f16*f16 -> f32 wgmma as the _SS struct above, but in the
// "RS" variant: A comes from registers (four 32-bit fragments a000..a003)
// while B is still a shared-memory descriptor.  Per the static_assert,
// register-sourced A must be K-major, so no tnspA immediate is emitted.
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x232x16_F32F16F16_RS
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];  // A fragment held in registers
  using BRegisters = uint64_t[1];  // shared-memory matrix descriptor for B
  using CRegisters = float[116];   // f32 accumulator fragment (116 = 232/2 values)

  static_assert(tnspA == GMMA::Major::K,
    "Register source operand A must have K major layout.");

  // Issues one wgmma instruction.  d000..d115 are read-modify-write
  // accumulator registers ("+f" constraints below).
  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
      uint64_t const& desc_b,
      float & d000, float & d001, float & d002, float & d003,
      float & d004, float & d005, float & d006, float & d007,
      float & d008, float & d009, float & d010, float & d011,
      float & d012, float & d013, float & d014, float & d015,
      float & d016, float & d017, float & d018, float & d019,
      float & d020, float & d021, float & d022, float & d023,
      float & d024, float & d025, float & d026, float & d027,
      float & d028, float & d029, float & d030, float & d031,
      float & d032, float & d033, float & d034, float & d035,
      float & d036, float & d037, float & d038, float & d039,
      float & d040, float & d041, float & d042, float & d043,
      float & d044, float & d045, float & d046, float & d047,
      float & d048, float & d049, float & d050, float & d051,
      float & d052, float & d053, float & d054, float & d055,
      float & d056, float & d057, float & d058, float & d059,
      float & d060, float & d061, float & d062, float & d063,
      float & d064, float & d065, float & d066, float & d067,
      float & d068, float & d069, float & d070, float & d071,
      float & d072, float & d073, float & d074, float & d075,
      float & d076, float & d077, float & d078, float & d079,
      float & d080, float & d081, float & d082, float & d083,
      float & d084, float & d085, float & d086, float & d087,
      float & d088, float & d089, float & d090, float & d091,
      float & d092, float & d093, float & d094, float & d095,
      float & d096, float & d097, float & d098, float & d099,
      float & d100, float & d101, float & d102, float & d103,
      float & d104, float & d105, float & d106, float & d107,
      float & d108, float & d109, float & d110, float & d111,
      float & d112, float & d113, float & d114, float & d115,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    // Operand numbering: %0-%115 accumulators, %116-%119 A fragment,
    // %120 desc_b, %121 scale_D (tested into predicate p), %122-%124
    // immediates scaleA/scaleB/tnspB.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %121, 0;\n"  // p = (scale_D != 0), forwarded as the wgmma scale-D operand
      "wgmma.mma_async.sync.aligned.m64n232k16.f32.f16.f16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103, "
      " %104, %105, %106, %107, %108, %109, %110, %111, "
      " %112, %113, %114, %115},"
      "{%116, %117, %118, %119},"
      " %120,"
      " p, %122, %123, %124;\n"
    "}\n"
      : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
        "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
        "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
        "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
        "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
        "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
        "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
        "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
        "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
        "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
        "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
        "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
        "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
        "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
        "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
        "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
        "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
        "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
        "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
        "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
        "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
        "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
        "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
        "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
        "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
        "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
        "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107),
        "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111),
        "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115)
      : "r"(a000), "r"(a001), "r"(a002), "r"(a003),
        "l"(desc_b),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x240x16 F32+=F16*F16
// Warpgroup-wide asynchronous MMA wrapper around
// PTX "wgmma.mma_async.sync.aligned.m64n240k16.f32.f16.f16".
// "SS" variant: both A and B are supplied as 64-bit shared-memory matrix
// descriptors.  tnspA/tnspB and scaleA/scaleB are immediate asm operands;
// scale_D is a runtime argument turned into the wgmma scale-D predicate.
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x240x16_F32F16F16_SS
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];  // shared-memory matrix descriptor for A
  using BRegisters = uint64_t[1];  // shared-memory matrix descriptor for B
  using CRegisters = float[120];   // f32 accumulator fragment (120 = 240/2 values)

  // Issues one wgmma instruction.  d000..d119 are read-modify-write
  // accumulator registers ("+f" constraints below).
  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d000, float & d001, float & d002, float & d003,
      float & d004, float & d005, float & d006, float & d007,
      float & d008, float & d009, float & d010, float & d011,
      float & d012, float & d013, float & d014, float & d015,
      float & d016, float & d017, float & d018, float & d019,
      float & d020, float & d021, float & d022, float & d023,
      float & d024, float & d025, float & d026, float & d027,
      float & d028, float & d029, float & d030, float & d031,
      float & d032, float & d033, float & d034, float & d035,
      float & d036, float & d037, float & d038, float & d039,
      float & d040, float & d041, float & d042, float & d043,
      float & d044, float & d045, float & d046, float & d047,
      float & d048, float & d049, float & d050, float & d051,
      float & d052, float & d053, float & d054, float & d055,
      float & d056, float & d057, float & d058, float & d059,
      float & d060, float & d061, float & d062, float & d063,
      float & d064, float & d065, float & d066, float & d067,
      float & d068, float & d069, float & d070, float & d071,
      float & d072, float & d073, float & d074, float & d075,
      float & d076, float & d077, float & d078, float & d079,
      float & d080, float & d081, float & d082, float & d083,
      float & d084, float & d085, float & d086, float & d087,
      float & d088, float & d089, float & d090, float & d091,
      float & d092, float & d093, float & d094, float & d095,
      float & d096, float & d097, float & d098, float & d099,
      float & d100, float & d101, float & d102, float & d103,
      float & d104, float & d105, float & d106, float & d107,
      float & d108, float & d109, float & d110, float & d111,
      float & d112, float & d113, float & d114, float & d115,
      float & d116, float & d117, float & d118, float & d119,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    // Operand numbering: %0-%119 accumulators, %120 desc_a, %121 desc_b,
    // %122 scale_D (tested into predicate p), %123-%126 immediates
    // scaleA/scaleB/tnspA/tnspB.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %122, 0;\n"  // p = (scale_D != 0), forwarded as the wgmma scale-D operand
      "wgmma.mma_async.sync.aligned.m64n240k16.f32.f16.f16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103, "
      " %104, %105, %106, %107, %108, %109, %110, %111, "
      " %112, %113, %114, %115, %116, %117, %118, %119},"
      " %120,"
      " %121,"
      " p, %123, %124, %125, %126;\n"
    "}\n"
      : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
        "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
        "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
        "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
        "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
        "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
        "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
        "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
        "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
        "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
        "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
        "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
        "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
        "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
        "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
        "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
        "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
        "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
        "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
        "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
        "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
        "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
        "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
        "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
        "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
        "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
        "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107),
        "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111),
        "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115),
        "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x240x16 F32+=F16*F16
// Same 64x240x16 f16*f16 -> f32 wgmma as the _SS struct above, but in the
// "RS" variant: A comes from registers (four 32-bit fragments a000..a003)
// while B is still a shared-memory descriptor.  Per the static_assert,
// register-sourced A must be K-major, so no tnspA immediate is emitted.
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x240x16_F32F16F16_RS
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];  // A fragment held in registers
  using BRegisters = uint64_t[1];  // shared-memory matrix descriptor for B
  using CRegisters = float[120];   // f32 accumulator fragment (120 = 240/2 values)

  static_assert(tnspA == GMMA::Major::K,
    "Register source operand A must have K major layout.");

  // Issues one wgmma instruction.  d000..d119 are read-modify-write
  // accumulator registers ("+f" constraints below).
  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
      uint64_t const& desc_b,
      float & d000, float & d001, float & d002, float & d003,
      float & d004, float & d005, float & d006, float & d007,
      float & d008, float & d009, float & d010, float & d011,
      float & d012, float & d013, float & d014, float & d015,
      float & d016, float & d017, float & d018, float & d019,
      float & d020, float & d021, float & d022, float & d023,
      float & d024, float & d025, float & d026, float & d027,
      float & d028, float & d029, float & d030, float & d031,
      float & d032, float & d033, float & d034, float & d035,
      float & d036, float & d037, float & d038, float & d039,
      float & d040, float & d041, float & d042, float & d043,
      float & d044, float & d045, float & d046, float & d047,
      float & d048, float & d049, float & d050, float & d051,
      float & d052, float & d053, float & d054, float & d055,
      float & d056, float & d057, float & d058, float & d059,
      float & d060, float & d061, float & d062, float & d063,
      float & d064, float & d065, float & d066, float & d067,
      float & d068, float & d069, float & d070, float & d071,
      float & d072, float & d073, float & d074, float & d075,
      float & d076, float & d077, float & d078, float & d079,
      float & d080, float & d081, float & d082, float & d083,
      float & d084, float & d085, float & d086, float & d087,
      float & d088, float & d089, float & d090, float & d091,
      float & d092, float & d093, float & d094, float & d095,
      float & d096, float & d097, float & d098, float & d099,
      float & d100, float & d101, float & d102, float & d103,
      float & d104, float & d105, float & d106, float & d107,
      float & d108, float & d109, float & d110, float & d111,
      float & d112, float & d113, float & d114, float & d115,
      float & d116, float & d117, float & d118, float & d119,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    // Operand numbering: %0-%119 accumulators, %120-%123 A fragment,
    // %124 desc_b, %125 scale_D (tested into predicate p), %126-%128
    // immediates scaleA/scaleB/tnspB.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %125, 0;\n"  // p = (scale_D != 0), forwarded as the wgmma scale-D operand
      "wgmma.mma_async.sync.aligned.m64n240k16.f32.f16.f16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103, "
      " %104, %105, %106, %107, %108, %109, %110, %111, "
      " %112, %113, %114, %115, %116, %117, %118, %119},"
      "{%120, %121, %122, %123},"
      " %124,"
      " p, %126, %127, %128;\n"
    "}\n"
      : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
        "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
        "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
        "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
        "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
        "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
        "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
        "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
        "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
        "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
        "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
        "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
        "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
        "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
        "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
        "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
        "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
        "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
        "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
        "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
        "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
        "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
        "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
        "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
        "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
        "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
        "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107),
        "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111),
        "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115),
        "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119)
      : "r"(a000), "r"(a001), "r"(a002), "r"(a003),
        "l"(desc_b),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x248x16 F32+=F16*F16
// Warpgroup-wide asynchronous MMA wrapper around
// PTX "wgmma.mma_async.sync.aligned.m64n248k16.f32.f16.f16".
// "SS" variant: both A and B are supplied as 64-bit shared-memory matrix
// descriptors.  tnspA/tnspB and scaleA/scaleB are immediate asm operands;
// scale_D is a runtime argument turned into the wgmma scale-D predicate.
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x248x16_F32F16F16_SS
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];  // shared-memory matrix descriptor for A
  using BRegisters = uint64_t[1];  // shared-memory matrix descriptor for B
  using CRegisters = float[124];   // f32 accumulator fragment (124 = 248/2 values)

  // Issues one wgmma instruction.  d000..d123 are read-modify-write
  // accumulator registers ("+f" constraints below).
  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d000, float & d001, float & d002, float & d003,
      float & d004, float & d005, float & d006, float & d007,
      float & d008, float & d009, float & d010, float & d011,
      float & d012, float & d013, float & d014, float & d015,
      float & d016, float & d017, float & d018, float & d019,
      float & d020, float & d021, float & d022, float & d023,
      float & d024, float & d025, float & d026, float & d027,
      float & d028, float & d029, float & d030, float & d031,
      float & d032, float & d033, float & d034, float & d035,
      float & d036, float & d037, float & d038, float & d039,
      float & d040, float & d041, float & d042, float & d043,
      float & d044, float & d045, float & d046, float & d047,
      float & d048, float & d049, float & d050, float & d051,
      float & d052, float & d053, float & d054, float & d055,
      float & d056, float & d057, float & d058, float & d059,
      float & d060, float & d061, float & d062, float & d063,
      float & d064, float & d065, float & d066, float & d067,
      float & d068, float & d069, float & d070, float & d071,
      float & d072, float & d073, float & d074, float & d075,
      float & d076, float & d077, float & d078, float & d079,
      float & d080, float & d081, float & d082, float & d083,
      float & d084, float & d085, float & d086, float & d087,
      float & d088, float & d089, float & d090, float & d091,
      float & d092, float & d093, float & d094, float & d095,
      float & d096, float & d097, float & d098, float & d099,
      float & d100, float & d101, float & d102, float & d103,
      float & d104, float & d105, float & d106, float & d107,
      float & d108, float & d109, float & d110, float & d111,
      float & d112, float & d113, float & d114, float & d115,
      float & d116, float & d117, float & d118, float & d119,
      float & d120, float & d121, float & d122, float & d123,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    // Operand numbering: %0-%123 accumulators, %124 desc_a, %125 desc_b,
    // %126 scale_D (tested into predicate p), %127-%130 immediates
    // scaleA/scaleB/tnspA/tnspB.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %126, 0;\n"  // p = (scale_D != 0), forwarded as the wgmma scale-D operand
      "wgmma.mma_async.sync.aligned.m64n248k16.f32.f16.f16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103, "
      " %104, %105, %106, %107, %108, %109, %110, %111, "
      " %112, %113, %114, %115, %116, %117, %118, %119, "
      " %120, %121, %122, %123},"
      " %124,"
      " %125,"
      " p, %127, %128, %129, %130;\n"
    "}\n"
      : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
        "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
        "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
        "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
        "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
        "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
        "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
        "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
        "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
        "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
        "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
        "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
        "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
        "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
        "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
        "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
        "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
        "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
        "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
        "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
        "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
        "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
        "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
        "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
        "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
        "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
        "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107),
        "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111),
        "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115),
        "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119),
        "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x248x16 F32+=F16*F16
// Warpgroup-wide wgmma.mma_async with an extended 64x248x16 tile.
// "RS" flavor: operand A (f16) is sourced from four 32-bit registers and must
// be K-major (enforced by the static_assert below); operand B (f16) is read
// from shared memory through a 64-bit matrix descriptor. The f32 accumulator D
// spans 124 values per fma call. scale_D is a runtime value: the setp below
// turns it into predicate p (per PTX wgmma semantics, p selects whether the
// prior accumulator contents are added). scaleA/scaleB/tnspB are baked into
// the instruction as "n" (immediate) operands.
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x248x16_F32F16F16_RS
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];   // A fragment lives in registers
  using BRegisters = uint64_t[1];   // B shared-memory matrix descriptor
  using CRegisters = float[124];    // f32 accumulator values

  static_assert(tnspA == GMMA::Major::K,
    "Register source operand A must have K major layout.");

  // a000..a003: operand-A fragment (registers).
  // desc_b    : shared-memory matrix descriptor for operand B.
  // d000..d123: accumulator, read-modify-written in place ("+f" constraints).
  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
      uint64_t const& desc_b,
      float & d000, float & d001, float & d002, float & d003,
      float & d004, float & d005, float & d006, float & d007,
      float & d008, float & d009, float & d010, float & d011,
      float & d012, float & d013, float & d014, float & d015,
      float & d016, float & d017, float & d018, float & d019,
      float & d020, float & d021, float & d022, float & d023,
      float & d024, float & d025, float & d026, float & d027,
      float & d028, float & d029, float & d030, float & d031,
      float & d032, float & d033, float & d034, float & d035,
      float & d036, float & d037, float & d038, float & d039,
      float & d040, float & d041, float & d042, float & d043,
      float & d044, float & d045, float & d046, float & d047,
      float & d048, float & d049, float & d050, float & d051,
      float & d052, float & d053, float & d054, float & d055,
      float & d056, float & d057, float & d058, float & d059,
      float & d060, float & d061, float & d062, float & d063,
      float & d064, float & d065, float & d066, float & d067,
      float & d068, float & d069, float & d070, float & d071,
      float & d072, float & d073, float & d074, float & d075,
      float & d076, float & d077, float & d078, float & d079,
      float & d080, float & d081, float & d082, float & d083,
      float & d084, float & d085, float & d086, float & d087,
      float & d088, float & d089, float & d090, float & d091,
      float & d092, float & d093, float & d094, float & d095,
      float & d096, float & d097, float & d098, float & d099,
      float & d100, float & d101, float & d102, float & d103,
      float & d104, float & d105, float & d106, float & d107,
      float & d108, float & d109, float & d110, float & d111,
      float & d112, float & d113, float & d114, float & d115,
      float & d116, float & d117, float & d118, float & d119,
      float & d120, float & d121, float & d122, float & d123,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    // Operand map: d000..d123 -> %0..%123, a000..a003 -> %124..%127,
    // desc_b -> %128, scale_D -> %129, scaleA/scaleB/tnspB -> %130..%132.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %129, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n248k16.f32.f16.f16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103, "
      " %104, %105, %106, %107, %108, %109, %110, %111, "
      " %112, %113, %114, %115, %116, %117, %118, %119, "
      " %120, %121, %122, %123},"
      "{%124, %125, %126, %127},"
      " %128,"
      " p, %130, %131, %132;\n"
    "}\n"
      : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
        "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
        "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
        "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
        "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
        "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
        "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
        "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
        "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
        "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
        "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
        "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
        "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
        "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
        "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
        "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
        "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
        "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
        "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
        "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
        "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
        "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
        "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
        "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
        "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
        "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
        "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107),
        "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111),
        "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115),
        "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119),
        "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123)
      :  "r"(a000),  "r"(a001),  "r"(a002),  "r"(a003),
         "l"(desc_b),
         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x24x16 F32+=BF16*BF16
// "SS" flavor: both A and B (bf16) are read from shared memory via 64-bit
// matrix descriptors; 12 f32 accumulators. tnspA/tnspB (operand major-ness)
// and scaleA/scaleB are compile-time immediates; scale_D is runtime and is
// lowered to predicate p by the setp below.
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x24x16_F32BF16BF16_SS
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];   // A shared-memory matrix descriptor
  using BRegisters = uint64_t[1];   // B shared-memory matrix descriptor
  using CRegisters = float[12];     // f32 accumulator values

  // Operand map: d00..d11 -> %0..%11, desc_a -> %12, desc_b -> %13,
  // scale_D -> %14, scaleA/scaleB/tnspA/tnspB -> %15..%18.
  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %14, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11},"
      " %12,"
      " %13,"
      " p, %15, %16, %17, %18;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11)
      :  "l"(desc_a),
         "l"(desc_b),
         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x24x16 F32+=BF16*BF16
// "RS" flavor: A (bf16) from four 32-bit registers (K-major only), B (bf16)
// from shared memory via descriptor; 12 f32 accumulators. Note tnspA is not
// passed to the instruction here — A's layout is fixed by the static_assert.
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x24x16_F32BF16BF16_RS
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using BRegisters = uint64_t[1];
  using CRegisters = float[12];

  static_assert(tnspA == GMMA::Major::K,
    "Register source operand A must have K major layout.");

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %17, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11},"
      "{%12, %13, %14, %15},"
      " %16,"
      " p, %18, %19, %20;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11)
      :  "r"(a00),  "r"(a01),  "r"(a02),  "r"(a03),
         "l"(desc_b),
         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x40x16 F32+=BF16*BF16
// SS flavor: A and B (bf16) via shared-memory descriptors; 20 f32 accumulators.
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x40x16_F32BF16BF16_SS
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[20];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %22, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n40k16.f32.bf16.bf16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19},"
      " %20,"
      " %21,"
      " p, %23, %24, %25, %26;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19)
      :  "l"(desc_a),
         "l"(desc_b),
         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x40x16 F32+=BF16*BF16
// RS flavor: A (bf16, K-major) in registers, B via descriptor; 20 accumulators.
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x40x16_F32BF16BF16_RS
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using BRegisters = uint64_t[1];
  using CRegisters = float[20];

  static_assert(tnspA == GMMA::Major::K,
    "Register source operand A must have K major layout.");

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %25, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n40k16.f32.bf16.bf16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19},"
      "{%20, %21, %22, %23},"
      " %24,"
      " p, %26, %27, %28;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19)
      :  "r"(a00),  "r"(a01),  "r"(a02),  "r"(a03),
         "l"(desc_b),
         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x48x16 F32+=BF16*BF16
// SS flavor: A and B (bf16) via shared-memory descriptors; 24 f32 accumulators.
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x48x16_F32BF16BF16_SS
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[24];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %26, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n48k16.f32.bf16.bf16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23},"
      " %24,"
      " %25,"
      " p, %27, %28, %29, %30;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23)
      :  "l"(desc_a),
         "l"(desc_b),
         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x48x16 F32+=BF16*BF16
// RS flavor: A (bf16, K-major) in registers, B via descriptor; 24 accumulators.
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x48x16_F32BF16BF16_RS
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using BRegisters = uint64_t[1];
  using CRegisters = float[24];

  static_assert(tnspA == GMMA::Major::K,
    "Register source operand A must have K major layout.");

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %29, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n48k16.f32.bf16.bf16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23},"
      "{%24, %25, %26, %27},"
      " %28,"
      " p, %30, %31, %32;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23)
      :  "r"(a00),  "r"(a01),  "r"(a02),  "r"(a03),
         "l"(desc_b),
         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x56x16 F32+=BF16*BF16
// SS flavor: A and B (bf16) via shared-memory descriptors; 28 f32 accumulators.
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x56x16_F32BF16BF16_SS
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[28];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %30, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n56k16.f32.bf16.bf16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27},"
      " %28,"
      " %29,"
      " p, %31, %32, %33, %34;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27)
      :  "l"(desc_a),
         "l"(desc_b),
         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x56x16 F32+=BF16*BF16
// RS flavor: A (bf16, K-major) in registers, B via descriptor; 28 accumulators.
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x56x16_F32BF16BF16_RS
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using BRegisters = uint64_t[1];
  using CRegisters = float[28];

  static_assert(tnspA == GMMA::Major::K,
    "Register source operand A must have K major layout.");

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %33, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n56k16.f32.bf16.bf16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27},"
      "{%28, %29, %30, %31},"
      " %32,"
      " p, %34, %35, %36;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27)
      :  "r"(a00),  "r"(a01),  "r"(a02),  "r"(a03),
         "l"(desc_b),
         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x72x16 F32+=BF16*BF16
// SS flavor: A and B (bf16) via shared-memory descriptors; 36 f32 accumulators.
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x72x16_F32BF16BF16_SS
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[36];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      float & d28, float & d29, float & d30, float & d31,
      float & d32, float & d33, float & d34, float & d35,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %38, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n72k16.f32.bf16.bf16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35},"
      " %36,"
      " %37,"
      " p, %39, %40, %41, %42;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35)
      :  "l"(desc_a),
         "l"(desc_b),
         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x72x16 F32+=BF16*BF16
// RS flavor: A (bf16, K-major) in registers, B via descriptor; 36 accumulators.
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x72x16_F32BF16BF16_RS
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using BRegisters = uint64_t[1];
  using CRegisters = float[36];

  static_assert(tnspA == GMMA::Major::K,
    "Register source operand A must have K major layout.");

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      float & d28, float & d29, float & d30, float & d31,
      float & d32, float & d33, float & d34, float & d35,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %41, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n72k16.f32.bf16.bf16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35},"
      "{%36, %37, %38, %39},"
      " %40,"
      " p, %42, %43, %44;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35)
      :  "r"(a00),  "r"(a01),  "r"(a02),  "r"(a03),
         "l"(desc_b),
         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x80x16 F32+=BF16*BF16
// SS flavor: A and B (bf16) via shared-memory descriptors; 40 f32 accumulators.
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x80x16_F32BF16BF16_SS
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[40];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      float & d28, float & d29, float & d30, float & d31,
      float & d32, float & d33, float & d34, float & d35,
      float & d36, float & d37, float & d38, float & d39,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %42, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n80k16.f32.bf16.bf16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39},"
      " %40,"
      " %41,"
      " p, %43, %44, %45, %46;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39)
      :  "l"(desc_a),
         "l"(desc_b),
         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x80x16 F32+=BF16*BF16
// "RS" flavor: operand A (bf16) is sourced from four 32-bit registers and must
// be K-major (static_assert below); operand B (bf16) is read from shared
// memory via a 64-bit matrix descriptor. 40 f32 accumulator values per fma
// call. scale_D is runtime (lowered to predicate p; per PTX wgmma semantics p
// selects whether the prior accumulator is added); scaleA/scaleB/tnspB are
// compile-time immediates.
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x80x16_F32BF16BF16_RS
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];   // A fragment lives in registers
  using BRegisters = uint64_t[1];   // B shared-memory matrix descriptor
  using CRegisters = float[40];     // f32 accumulator values

  static_assert(tnspA == GMMA::Major::K,
    "Register source operand A must have K major layout.");

  // Operand map: d00..d39 -> %0..%39, a00..a03 -> %40..%43, desc_b -> %44,
  // scale_D -> %45, scaleA/scaleB/tnspB -> %46..%48.
  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      float & d28, float & d29, float & d30, float & d31,
      float & d32, float & d33, float & d34, float & d35,
      float & d36, float & d37, float & d38, float & d39,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %45, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n80k16.f32.bf16.bf16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39},"
      "{%40, %41, %42, %43},"
      " %44,"
      " p, %46, %47, %48;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39)
      :  "r"(a00),  "r"(a01),  "r"(a02),  "r"(a03),
         "l"(desc_b),
         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x88x16 F32+=BF16*BF16
// "SS" flavor: both A and B (bf16) are read from shared memory via 64-bit
// matrix descriptors; 44 f32 accumulators. tnspA/tnspB and scaleA/scaleB are
// compile-time immediates; scale_D is runtime.
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x88x16_F32BF16BF16_SS
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];   // A shared-memory matrix descriptor
  using BRegisters = uint64_t[1];   // B shared-memory matrix descriptor
  using CRegisters = float[44];     // f32 accumulator values

  // Operand map: d00..d43 -> %0..%43, desc_a -> %44, desc_b -> %45,
  // scale_D -> %46, scaleA/scaleB/tnspA/tnspB -> %47..%50.
  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      float & d28, float & d29, float & d30, float & d31,
      float & d32, float & d33, float & d34, float & d35,
      float & d36, float & d37, float & d38, float & d39,
      float & d40, float & d41, float & d42, float & d43,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %46, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n88k16.f32.bf16.bf16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43},"
      " %44,"
      " %45,"
      " p, %47, %48, %49, %50;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43)
      :  "l"(desc_a),
         "l"(desc_b),
         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x88x16 F32+=BF16*BF16
// RS flavor: A (bf16, K-major) in registers, B via descriptor; 44 accumulators.
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x88x16_F32BF16BF16_RS
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using BRegisters = uint64_t[1];
  using CRegisters = float[44];

  static_assert(tnspA == GMMA::Major::K,
    "Register source operand A must have K major layout.");

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      float & d28, float & d29, float & d30, float & d31,
      float & d32, float & d33, float & d34, float & d35,
      float & d36, float & d37, float & d38, float & d39,
      float & d40, float & d41, float & d42, float & d43,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %49, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n88k16.f32.bf16.bf16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43},"
      "{%44, %45, %46, %47},"
      " %48,"
      " p, %50, %51, %52;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43)
      :  "r"(a00),  "r"(a01),  "r"(a02),  "r"(a03),
         "l"(desc_b),
         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x104x16 F32+=BF16*BF16
template
< + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), 
"+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & 
d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = 
uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), 
"+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + 
using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), 
+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, 
float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %70, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " p, %71, %72, %73, 
%74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, 
+ float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %73, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " p, %74, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), 
"+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & 
d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p, %75, %76, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x144x16 F32+=BF16*BF16
+// NOTE(review): auto-generated wrapper around `wgmma.mma_async.sync.aligned.m64n144k16`.
+// RS variant: operand A comes from 4 x 32-bit registers, operand B via a smem descriptor;
+// the 64x144 f32 accumulator tile occupies 72 "+f" register operands (%0..%71), A is
+// %72..%75, desc_b is %76, and the predicate `p` gates accumulation on scale_D (%77).
+template <
+  GMMA::Major tnspA,
+  GMMA::Major tnspB,
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x144x16_F32BF16BF16_RS
+{
+  using DRegisters = void;
+  using ARegisters = uint32_t[4];
+  using BRegisters = uint64_t[1];
+  using CRegisters = float[72];
+
+  static_assert(tnspA == GMMA::Major::K,
+    "Register source operand A must have K major layout.");
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
+      uint64_t const& desc_b,
+      float & d00, float & d01, float & d02, float & d03,
+      float & d04, float & d05, float & d06, float & d07,
+      float & d08, float & d09, float & d10, float & d11,
+      float & d12, float & d13, float & d14, float & d15,
+      float & d16, float & d17, float & d18, float & d19,
+      float & d20, float & d21, float & d22, float & d23,
+      float & d24, float & d25, float & d26, float & d27,
+      float & d28, float & d29, float & d30, float & d31,
+      float & d32, float & d33, float & d34, float & d35,
+      float & d36, float & d37, float & d38, float & d39,
+      float & d40, float & d41, float & d42, float & d43,
+      float & d44, float & d45, float & d46, float & d47,
+      float & d48, float & d49, float & d50, float & d51,
+      float & d52, float & d53, float & d54, float & d55,
+      float & d56, float & d57, float & d58, float & d59,
+      float & d60, float & d61, float & d62, float & d63,
+      float & d64, float & d65, float & d66, float & d67,
+      float & d68, float & d69, float & d70, float & d71,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %77, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n144k16.f32.bf16.bf16 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27, %28, %29, %30, %31, "
+      " %32, %33, %34, %35, %36, %37, %38, %39, "
+      " %40, %41, %42, %43, %44, %45, %46, %47, "
+      " %48, %49, %50, %51, %52, %53, %54, %55, "
+      " %56, %57, %58, %59, %60, %61, %62, %63, "
+      " %64, %65, %66, %67, %68, %69, %70, %71},"
+      "{%72, %73, %74, %75},"
+      " %76,"
+      " p, %78, %79, %80;\n"
+    "}\n"
+      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
+        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
+        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
+        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
+        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
+        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
+        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
+        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
+        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
+        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
+        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
+        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
+        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
+        "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
+        "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
+        "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
+        "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
+        "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71)
+      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x152x16 F32+=BF16*BF16
+// NOTE(review): SS variant — both A and B are sourced from shared memory via 64-bit
+// descriptors (desc_a/desc_b); accumulator is 76 f32 registers (%0..%75), desc_a %76,
+// desc_b %77, predicate from scale_D (%78), immediates scaleA/scaleB/tnspA/tnspB (%79..%82).
+template <
+  GMMA::Major tnspA,
+  GMMA::Major tnspB,
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x152x16_F32BF16BF16_SS
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = float[76];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      float & d00, float & d01, float & d02, float & d03,
+      float & d04, float & d05, float & d06, float & d07,
+      float & d08, float & d09, float & d10, float & d11,
+      float & d12, float & d13, float & d14, float & d15,
+      float & d16, float & d17, float & d18, float & d19,
+      float & d20, float & d21, float & d22, float & d23,
+      float & d24, float & d25, float & d26, float & d27,
+      float & d28, float & d29, float & d30, float & d31,
+      float & d32, float & d33, float & d34, float & d35,
+      float & d36, float & d37, float & d38, float & d39,
+      float & d40, float & d41, float & d42, float & d43,
+      float & d44, float & d45, float & d46, float & d47,
+      float & d48, float & d49, float & d50, float & d51,
+      float & d52, float & d53, float & d54, float & d55,
+      float & d56, float & d57, float & d58, float & d59,
+      float & d60, float & d61, float & d62, float & d63,
+      float & d64, float & d65, float & d66, float & d67,
+      float & d68, float & d69, float & d70, float & d71,
+      float & d72, float & d73, float & d74, float & d75,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %78, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n152k16.f32.bf16.bf16 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27, %28, %29, %30, %31, "
+      " %32, %33, %34, %35, %36, %37, %38, %39, "
+      " %40, %41, %42, %43, %44, %45, %46, %47, "
+      " %48, %49, %50, %51, %52, %53, %54, %55, "
+      " %56, %57, %58, %59, %60, %61, %62, %63, "
+      " %64, %65, %66, %67, %68, %69, %70, %71, "
+      " %72, %73, %74, %75},"
+      " %76,"
+      " %77,"
+      " p, %79, %80, %81, %82;\n"
+    "}\n"
+      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
+        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
+        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
+        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
+        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
+        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
+        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
+        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
+        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
+        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
+        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
+        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
+        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
+        "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
+        "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
+        "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
+        "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
+        "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
+        "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75)
+      : "l"(desc_a),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x152x16 F32+=BF16*BF16
+// NOTE(review): RS variant of the shape above — A in registers (%76..%79), desc_b %80,
+// predicate from scale_D (%81); tnspA must be K-major (enforced by the static_assert).
+template <
+  GMMA::Major tnspA,
+  GMMA::Major tnspB,
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x152x16_F32BF16BF16_RS
+{
+  using DRegisters = void;
+  using ARegisters = uint32_t[4];
+  using BRegisters = uint64_t[1];
+  using CRegisters = float[76];
+
+  static_assert(tnspA == GMMA::Major::K,
+    "Register source operand A must have K major layout.");
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
+      uint64_t const& desc_b,
+      float & d00, float & d01, float & d02, float & d03,
+      float & d04, float & d05, float & d06, float & d07,
+      float & d08, float & d09, float & d10, float & d11,
+      float & d12, float & d13, float & d14, float & d15,
+      float & d16, float & d17, float & d18, float & d19,
+      float & d20, float & d21, float & d22, float & d23,
+      float & d24, float & d25, float & d26, float & d27,
+      float & d28, float & d29, float & d30, float & d31,
+      float & d32, float & d33, float & d34, float & d35,
+      float & d36, float & d37, float & d38, float & d39,
+      float & d40, float & d41, float & d42, float & d43,
+      float & d44, float & d45, float & d46, float & d47,
+      float & d48, float & d49, float & d50, float & d51,
+      float & d52, float & d53, float & d54, float & d55,
+      float & d56, float & d57, float & d58, float & d59,
+      float & d60, float & d61, float & d62, float & d63,
+      float & d64, float & d65, float & d66, float & d67,
+      float & d68, float & d69, float & d70, float & d71,
+      float & d72, float & d73, float & d74, float & d75,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %81, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n152k16.f32.bf16.bf16 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27, %28, %29, %30, %31, "
+      " %32, %33, %34, %35, %36, %37, %38, %39, "
+      " %40, %41, %42, %43, %44, %45, %46, %47, "
+      " %48, %49, %50, %51, %52, %53, %54, %55, "
+      " %56, %57, %58, %59, %60, %61, %62, %63, "
+      " %64, %65, %66, %67, %68, %69, %70, %71, "
+      " %72, %73, %74, %75},"
+      "{%76, %77, %78, %79},"
+      " %80,"
+      " p, %82, %83, %84;\n"
+    "}\n"
+      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
+        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
+        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
+        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
+        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
+        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
+        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
+        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
+        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
+        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
+        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
+        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
+        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
+        "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
+        "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
+        "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
+        "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
+        "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
+        "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75)
+      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x160x16 F32+=BF16*BF16
+// NOTE(review): SS variant — accumulator is 80 f32 registers (%0..%79); desc_a %80,
+// desc_b %81, predicate from scale_D (%82), immediates %83..%86.
+template <
+  GMMA::Major tnspA,
+  GMMA::Major tnspB,
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x160x16_F32BF16BF16_SS
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = float[80];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      float & d00, float & d01, float & d02, float & d03,
+      float & d04, float & d05, float & d06, float & d07,
+      float & d08, float & d09, float & d10, float & d11,
+      float & d12, float & d13, float & d14, float & d15,
+      float & d16, float & d17, float & d18, float & d19,
+      float & d20, float & d21, float & d22, float & d23,
+      float & d24, float & d25, float & d26, float & d27,
+      float & d28, float & d29, float & d30, float & d31,
+      float & d32, float & d33, float & d34, float & d35,
+      float & d36, float & d37, float & d38, float & d39,
+      float & d40, float & d41, float & d42, float & d43,
+      float & d44, float & d45, float & d46, float & d47,
+      float & d48, float & d49, float & d50, float & d51,
+      float & d52, float & d53, float & d54, float & d55,
+      float & d56, float & d57, float & d58, float & d59,
+      float & d60, float & d61, float & d62, float & d63,
+      float & d64, float & d65, float & d66, float & d67,
+      float & d68, float & d69, float & d70, float & d71,
+      float & d72, float & d73, float & d74, float & d75,
+      float & d76, float & d77, float & d78, float & d79,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %82, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n160k16.f32.bf16.bf16 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27, %28, %29, %30, %31, "
+      " %32, %33, %34, %35, %36, %37, %38, %39, "
+      " %40, %41, %42, %43, %44, %45, %46, %47, "
+      " %48, %49, %50, %51, %52, %53, %54, %55, "
+      " %56, %57, %58, %59, %60, %61, %62, %63, "
+      " %64, %65, %66, %67, %68, %69, %70, %71, "
+      " %72, %73, %74, %75, %76, %77, %78, %79},"
+      " %80,"
+      " %81,"
+      " p, %83, %84, %85, %86;\n"
+    "}\n"
+      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
+        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
+        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
+        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
+        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
+        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
+        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
+        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
+        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
+        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
+        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
+        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
+        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
+        "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
+        "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
+        "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
+        "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
+        "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
+        "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75),
+        "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79)
+      : "l"(desc_a),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x160x16 F32+=BF16*BF16
+// NOTE(review): RS variant — A in registers (%80..%83), desc_b %84, predicate from
+// scale_D (%85), immediates %86..%88.
+template <
+  GMMA::Major tnspA,
+  GMMA::Major tnspB,
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x160x16_F32BF16BF16_RS
+{
+  using DRegisters = void;
+  using ARegisters = uint32_t[4];
+  using BRegisters = uint64_t[1];
+  using CRegisters = float[80];
+
+  static_assert(tnspA == GMMA::Major::K,
+    "Register source operand A must have K major layout.");
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
+      uint64_t const& desc_b,
+      float & d00, float & d01, float & d02, float & d03,
+      float & d04, float & d05, float & d06, float & d07,
+      float & d08, float & d09, float & d10, float & d11,
+      float & d12, float & d13, float & d14, float & d15,
+      float & d16, float & d17, float & d18, float & d19,
+      float & d20, float & d21, float & d22, float & d23,
+      float & d24, float & d25, float & d26, float & d27,
+      float & d28, float & d29, float & d30, float & d31,
+      float & d32, float & d33, float & d34, float & d35,
+      float & d36, float & d37, float & d38, float & d39,
+      float & d40, float & d41, float & d42, float & d43,
+      float & d44, float & d45, float & d46, float & d47,
+      float & d48, float & d49, float & d50, float & d51,
+      float & d52, float & d53, float & d54, float & d55,
+      float & d56, float & d57, float & d58, float & d59,
+      float & d60, float & d61, float & d62, float & d63,
+      float & d64, float & d65, float & d66, float & d67,
+      float & d68, float & d69, float & d70, float & d71,
+      float & d72, float & d73, float & d74, float & d75,
+      float & d76, float & d77, float & d78, float & d79,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %85, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n160k16.f32.bf16.bf16 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27, %28, %29, %30, %31, "
+      " %32, %33, %34, %35, %36, %37, %38, %39, "
+      " %40, %41, %42, %43, %44, %45, %46, %47, "
+      " %48, %49, %50, %51, %52, %53, %54, %55, "
+      " %56, %57, %58, %59, %60, %61, %62, %63, "
+      " %64, %65, %66, %67, %68, %69, %70, %71, "
+      " %72, %73, %74, %75, %76, %77, %78, %79},"
+      "{%80, %81, %82, %83},"
+      " %84,"
+      " p, %86, %87, %88;\n"
+    "}\n"
+      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
+        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
+        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
+        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
+        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
+        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
+        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
+        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
+        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
+        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
+        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
+        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
+        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
+        "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
+        "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
+        "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
+        "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
+        "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
+        "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75),
+        "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79)
+      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x168x16 F32+=BF16*BF16
+// NOTE(review): SS variant — 84 f32 accumulators (%0..%83); desc_a %84, desc_b %85,
+// predicate from scale_D (%86), immediates %87..%90.
+template <
+  GMMA::Major tnspA,
+  GMMA::Major tnspB,
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x168x16_F32BF16BF16_SS
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = float[84];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      float & d00, float & d01, float & d02, float & d03,
+      float & d04, float & d05, float & d06, float & d07,
+      float & d08, float & d09, float & d10, float & d11,
+      float & d12, float & d13, float & d14, float & d15,
+      float & d16, float & d17, float & d18, float & d19,
+      float & d20, float & d21, float & d22, float & d23,
+      float & d24, float & d25, float & d26, float & d27,
+      float & d28, float & d29, float & d30, float & d31,
+      float & d32, float & d33, float & d34, float & d35,
+      float & d36, float & d37, float & d38, float & d39,
+      float & d40, float & d41, float & d42, float & d43,
+      float & d44, float & d45, float & d46, float & d47,
+      float & d48, float & d49, float & d50, float & d51,
+      float & d52, float & d53, float & d54, float & d55,
+      float & d56, float & d57, float & d58, float & d59,
+      float & d60, float & d61, float & d62, float & d63,
+      float & d64, float & d65, float & d66, float & d67,
+      float & d68, float & d69, float & d70, float & d71,
+      float & d72, float & d73, float & d74, float & d75,
+      float & d76, float & d77, float & d78, float & d79,
+      float & d80, float & d81, float & d82, float & d83,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %86, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n168k16.f32.bf16.bf16 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27, %28, %29, %30, %31, "
+      " %32, %33, %34, %35, %36, %37, %38, %39, "
+      " %40, %41, %42, %43, %44, %45, %46, %47, "
+      " %48, %49, %50, %51, %52, %53, %54, %55, "
+      " %56, %57, %58, %59, %60, %61, %62, %63, "
+      " %64, %65, %66, %67, %68, %69, %70, %71, "
+      " %72, %73, %74, %75, %76, %77, %78, %79, "
+      " %80, %81, %82, %83},"
+      " %84,"
+      " %85,"
+      " p, %87, %88, %89, %90;\n"
+    "}\n"
+      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
+        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
+        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
+        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
+        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
+        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
+        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
+        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
+        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
+        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
+        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
+        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
+        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
+        "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
+        "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
+        "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
+        "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
+        "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
+        "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75),
+        "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
+        "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83)
+      : "l"(desc_a),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x168x16 F32+=BF16*BF16
+// NOTE(review): RS variant — A in registers (%84..%87), desc_b %88, predicate from
+// scale_D (%89), immediates %90..%92.
+template <
+  GMMA::Major tnspA,
+  GMMA::Major tnspB,
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x168x16_F32BF16BF16_RS
+{
+  using DRegisters = void;
+  using ARegisters = uint32_t[4];
+  using BRegisters = uint64_t[1];
+  using CRegisters = float[84];
+
+  static_assert(tnspA == GMMA::Major::K,
+    "Register source operand A must have K major layout.");
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
+      uint64_t const& desc_b,
+      float & d00, float & d01, float & d02, float & d03,
+      float & d04, float & d05, float & d06, float & d07,
+      float & d08, float & d09, float & d10, float & d11,
+      float & d12, float & d13, float & d14, float & d15,
+      float & d16, float & d17, float & d18, float & d19,
+      float & d20, float & d21, float & d22, float & d23,
+      float & d24, float & d25, float & d26, float & d27,
+      float & d28, float & d29, float & d30, float & d31,
+      float & d32, float & d33, float & d34, float & d35,
+      float & d36, float & d37, float & d38, float & d39,
+      float & d40, float & d41, float & d42, float & d43,
+      float & d44, float & d45, float & d46, float & d47,
+      float & d48, float & d49, float & d50, float & d51,
+      float & d52, float & d53, float & d54, float & d55,
+      float & d56, float & d57, float & d58, float & d59,
+      float & d60, float & d61, float & d62, float & d63,
+      float & d64, float & d65, float & d66, float & d67,
+      float & d68, float & d69, float & d70, float & d71,
+      float & d72, float & d73, float & d74, float & d75,
+      float & d76, float & d77, float & d78, float & d79,
+      float & d80, float & d81, float & d82, float & d83,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %89, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n168k16.f32.bf16.bf16 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27, %28, %29, %30, %31, "
+      " %32, %33, %34, %35, %36, %37, %38, %39, "
+      " %40, %41, %42, %43, %44, %45, %46, %47, "
+      " %48, %49, %50, %51, %52, %53, %54, %55, "
+      " %56, %57, %58, %59, %60, %61, %62, %63, "
+      " %64, %65, %66, %67, %68, %69, %70, %71, "
+      " %72, %73, %74, %75, %76, %77, %78, %79, "
+      " %80, %81, %82, %83},"
+      "{%84, %85, %86, %87},"
+      " %88,"
+      " p, %90, %91, %92;\n"
+    "}\n"
+      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
+        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
+        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
+        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
+        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
+        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
+        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
+        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
+        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
+        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
+        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
+        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
+        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
+        "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
+        "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
+        "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
+        "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
+        "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
+        "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75),
+        "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
+        "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83)
+      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x176x16 F32+=BF16*BF16
+// NOTE(review): SS variant — 88 f32 accumulators (%0..%87); desc_a %88, desc_b %89,
+// predicate from scale_D (%90), immediates scaleA/scaleB/tnspA/tnspB (%91..%94).
+template <
+  GMMA::Major tnspA,
+  GMMA::Major tnspB,
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x176x16_F32BF16BF16_SS
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = float[88];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      float & d00, float & d01, float & d02, float & d03,
+      float & d04, float & d05, float & d06, float & d07,
+      float & d08, float & d09, float & d10, float & d11,
+      float & d12, float & d13, float & d14, float & d15,
+      float & d16, float & d17, float & d18, float & d19,
+      float & d20, float & d21, float & d22, float & d23,
+      float & d24, float & d25, float & d26, float & d27,
+      float & d28, float & d29, float & d30, float & d31,
+      float & d32, float & d33, float & d34, float & d35,
+      float & d36, float & d37, float & d38, float & d39,
+      float & d40, float & d41, float & d42, float & d43,
+      float & d44, float & d45, float & d46, float & d47,
+      float & d48, float & d49, float & d50, float & d51,
+      float & d52, float & d53, float & d54, float & d55,
+      float & d56, float & d57, float & d58, float & d59,
+      float & d60, float & d61, float & d62, float & d63,
+      float & d64, float & d65, float & d66, float & d67,
+      float & d68, float & d69, float & d70, float & d71,
+      float & d72, float & d73, float & d74, float & d75,
+      float & d76, float & d77, float & d78, float & d79,
+      float & d80, float & d81, float & d82, float & d83,
+      float & d84, float & d85, float & d86, float & d87,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %90, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n176k16.f32.bf16.bf16 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27, %28, %29, %30, %31, "
+      " %32, %33, %34, %35, %36, %37, %38, %39, "
+      " %40, %41, %42, %43, %44, %45, %46, %47, "
+      " %48, %49, %50, %51, %52, %53, %54, %55, "
+      " %56, %57, %58, %59, %60, %61, %62, %63, "
+      " %64, %65, %66, %67, %68, %69, %70, %71, "
+      " %72, %73, %74, %75, %76, %77, %78, %79, "
+      " %80, %81, %82, %83, %84, %85, %86, %87},"
+      " %88,"
+      " %89,"
+      " p, %91, %92, %93, %94;\n"
+    "}\n"
+      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
+        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
+        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
+        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
+        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
+        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
+        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
+        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
+        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
+        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
+        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
+        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
+        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
+        "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
+        "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
+        "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
+        "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
+        "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
+        "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75),
+        "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
+        "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83),
+        "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87)
+      : "l"(desc_a),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x176x16 F32+=BF16*BF16
+// NOTE(review): RS variant — A in registers (%88..%91), desc_b %92, predicate from
+// scale_D (%93), immediates %94..%96.
+template <
+  GMMA::Major tnspA,
+  GMMA::Major tnspB,
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x176x16_F32BF16BF16_RS
+{
+  using DRegisters = void;
+  using ARegisters = uint32_t[4];
+  using BRegisters = uint64_t[1];
+  using CRegisters = float[88];
+
+  static_assert(tnspA == GMMA::Major::K,
+    "Register source operand A must have K major layout.");
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
+      uint64_t const& desc_b,
+      float & d00, float & d01, float & d02, float & d03,
+      float & d04, float & d05, float & d06, float & d07,
+      float & d08, float & d09, float & d10, float & d11,
+      float & d12, float & d13, float & d14, float & d15,
+      float & d16, float & d17, float & d18, float & d19,
+      float & d20, float & d21, float & d22, float & d23,
+      float & d24, float & d25, float & d26, float & d27,
+      float & d28, float & d29, float & d30, float & d31,
+      float & d32, float & d33, float & d34, float & d35,
+      float & d36, float & d37, float & d38, float & d39,
+      float & d40, float & d41, float & d42, float & d43,
+      float & d44, float & d45, float & d46, float & d47,
+      float & d48, float & d49, float & d50, float & d51,
+      float & d52, float & d53, float & d54, float & d55,
+      float & d56, float & d57, float & d58, float & d59,
+      float & d60, float & d61, float & d62, float & d63,
+      float & d64, float & d65, float & d66, float & d67,
+      float & d68, float & d69, float & d70, float & d71,
+      float & d72, float & d73, float & d74, float & d75,
+      float & d76, float & d77, float & d78, float & d79,
+      float & d80, float & d81, float & d82, float & d83,
+      float & d84, float & d85, float & d86, float & d87,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %93, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n176k16.f32.bf16.bf16 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27, %28, %29, %30, %31, "
+      " %32, %33, %34, %35, %36, %37, %38, %39, "
+      " %40, %41, %42, %43, %44, %45, %46, %47, "
+      " %48, %49, %50, %51, %52, %53, %54, %55, "
+      " %56, %57, %58, %59, %60, %61, %62, %63, "
+      " %64, %65, %66, %67, %68, %69, %70, %71, "
+      " %72, %73, %74, %75, %76, %77, %78, %79, "
+      " %80, %81, %82, %83, %84, %85, %86, %87},"
+      "{%88, %89, %90, %91},"
+      " %92,"
+      " p, %94, %95, %96;\n"
+    "}\n"
+      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
+        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
+        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
+        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
+        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
+        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
+        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
+        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
+        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
+        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
+        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
+        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
+        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
+        "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
+        "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
+        "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
+        "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
+        "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
+        "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75),
+        "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
+        "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83),
+        "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87)
+      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x184x16 F32+=BF16*BF16
+// NOTE(review): SS variant — 92 f32 accumulators (%0..%91); desc_a %92, desc_b %93,
+// predicate from scale_D (%94), immediates %95..%98.
+template <
+  GMMA::Major tnspA,
+  GMMA::Major tnspB,
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x184x16_F32BF16BF16_SS
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = float[92];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      float & d00, float & d01, float & d02, float & d03,
+      float & d04, float & d05, float & d06, float & d07,
+      float & d08, float & d09, float & d10, float & d11,
+      float & d12, float & d13, float & d14, float & d15,
+      float & d16, float & d17, float & d18, float & d19,
+      float & d20, float & d21, float & d22, float & d23,
+      float & d24, float & d25, float & d26, float & d27,
+      float & d28, float & d29, float & d30, float & d31,
+      float & d32, float & d33, float & d34, float & d35,
+      float & d36, float & d37, float & d38, float & d39,
+      float & d40, float & d41, float & d42, float & d43,
+      float & d44, float & d45, float & d46, float & d47,
+      float & d48, float & d49, float & d50, float & d51,
+      float & d52, float & d53, float & d54, float & d55,
+      float & d56, float & d57, float & d58, float & d59,
+      float & d60, float & d61, float & d62, float & d63,
+      float & d64, float & d65, float & d66, float & d67,
+      float & d68, float & d69, float & d70, float & d71,
+      float & d72, float & d73, float & d74, float & d75,
+      float & d76, float & d77, float & d78, float & d79,
+      float & d80, float & d81, float & d82, float & d83,
+      float & d84, float & d85, float & d86, float & d87,
+      float & d88, float & d89, float & d90, float & d91,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %94, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n184k16.f32.bf16.bf16 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27, %28, %29, %30, %31, "
+      " %32, %33, %34, %35, %36, %37, %38, %39, "
+      " %40, %41, %42, %43, %44, %45, %46, %47, "
+      " %48, %49, %50, %51, %52, %53, %54, %55, "
+      " %56, %57, %58, %59, %60, %61, %62, %63, "
+      " %64, %65, %66, %67, %68, %69, %70, %71, "
+      " %72, %73, %74, %75, %76, %77, %78, %79, "
+      " %80, %81, %82, %83, %84, %85, %86, %87, "
+      " %88, %89, %90, %91},"
+      " %92,"
+      " %93,"
+      " p, %95, %96, %97, %98;\n"
+    "}\n"
+      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
+        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
+        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
+        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
+        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
+        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
+        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
+        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
+        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
+        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
+        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
+        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
+        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
+        "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
+        "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
+        "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
+        "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
+        "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
+        "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75),
+        "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
+        "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83),
+        "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87),
+        "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91)
+      : "l"(desc_a),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x184x16 F32+=BF16*BF16
+template <
+  GMMA::Major tnspA,
+  GMMA::Major tnspB,
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x184x16_F32BF16BF16_RS
+{
+  using DRegisters = void;
+  using ARegisters = uint32_t[4];
+  using BRegisters = uint64_t[1];
+  using CRegisters = float[92];
+
+  static_assert(tnspA == GMMA::Major::K,
+    "Register source operand A must have K major layout.");
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
+      uint64_t const& desc_b,
+      float & d00, float & d01, float & d02, float & d03,
+      float & d04, float & d05, float & d06, float & d07,
+      float & d08, float & d09, float & d10, float & d11,
+      float & d12, float & d13, float & d14, float & d15,
+      float & d16, float & d17, float & d18, float & d19,
+      float & d20, float & d21, float & d22, float & d23,
+      float & d24, float & d25, float & d26, float & d27,
+      float & d28, float & d29, float & d30, float & d31,
+      float & d32, float & d33, float & d34, float & d35,
+      float & d36, float & d37, float & d38, float & d39,
+      float & d40, float & d41, float & d42, float & d43,
+      float & d44, float & d45, float & d46, float & d47,
+      float & d48, float & d49, float & d50, float & d51,
+      float & d52, float & d53, float & d54, float & d55,
+      float & d56, float & d57, float & d58, float & d59,
+      float & d60, float & d61, float & d62, float & d63,
+      float & d64, float & d65, float & d66, float & d67,
+      float & d68, float & d69, float & d70, float & d71,
+      float & d72, float & d73, float & d74, float & d75,
+      float & d76, float & d77, float & d78, float & d79,
+      float & d80, float & d81, float & d82, float & d83,
+      float & d84, float & d85, float
& d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %97, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " p, %98, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), 
"+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, 
float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %102, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " p, %103, %104, %105, %106;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + 
"+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, 
float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %105, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " p, %106, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), 
"+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, 
float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, 
%39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p, %107, %108, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F32BF16BF16_SS without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & 
d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p, %110, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), 
"+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float 
& d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %110, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " p, %111, %112, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), 
"+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using 
CRegisters = float[108]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, 
desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %113, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " p, %114, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + 
"+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & 
d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p, %115, %116, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), 
"+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& 
a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k16.f32.bf16.bf16 " + 
"{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p, %118, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), 
"+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & 
d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %118, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " p, %119, %120, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + 
"+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + 
CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + 
asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %121, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " p, %122, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), 
"+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, 
float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p, %123, %124, %125, %126;\n" + "}\n" + : "+f"(d000), 
"+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + 
float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p, %126, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), 
"+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + 
float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %126, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, 
%44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " p, %127, %128, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), 
"+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & 
d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %129, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " p, %130, %131, %132;\n" 
+ "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x8 TN F32+=TF32*TF32 
+template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & 
d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, 
%10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), 
"+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x56x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + 
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), 
"+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, 
%7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & 
d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & 
d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), 
"+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, 
%28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, 
float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), 
"+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void 
+ fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), 
"+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static 
void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %70, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " p, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + 
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, 
float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %73, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " p, %74, %75;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x8_F32TF32TF32_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, 
" + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, 
float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p, %78, %79;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), 
"+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, 
float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %78, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " p, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred 
p;\n" + "setp.ne.b32 p, %81, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " p, %82, %83;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct 
MMA_64x160x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " 
%64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, 
float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p, %86, %87;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), 
"+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & 
d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %86, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " p, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), 
"+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & 
d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %89, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " p, %90, %91;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + 
"+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & 
d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p, %94, %95;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %94, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " p, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x184x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %97, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " p, %98, %99;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, 
float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %102, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " p, %103, %104;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), 
"+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, 
float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %105, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " p, %106, %107;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), 
"+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, 
float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), 
"+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, 
uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, 
%19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p, %110, %111;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); 
+#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, 
float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %110, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " p, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + 
"+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + 
float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %113, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " p, %114, %115;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), 
"+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + 
fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k8.f32.tf32.tf32 " + "{%0, %1, %2, 
%3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), 
"+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + 
float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p, %118, %119;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), 
"+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & 
d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %118, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, 
%53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " p, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & 
d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %121, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " p, %122, %123;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), 
"+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, 
float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, 
%51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), 
+ "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + 
float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p, %126, %127;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + 
"+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float 
& d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %126, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, 
%4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " p, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), 
"+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & 
d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %129, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " 
%128," + " p, %130, %131;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=S8*S8 
+struct MMA_64x24x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=S8*S8 +struct MMA_64x24x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg 
.pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=S8*S8 +struct MMA_64x48x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + 
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=S8*S8 +struct MMA_64x48x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8S8_SS_TN_SATURATE without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=S8*S8 +struct MMA_64x80x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), 
"+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=S8*S8 +struct MMA_64x80x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + 
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=S8*S8 +struct MMA_64x112x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=S8*S8 +struct MMA_64x112x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, 
uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), 
"+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=S8*S8 +struct MMA_64x144x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=S8*S8 +struct MMA_64x144x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + 
CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + 
" %73," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=S8*S8 +struct MMA_64x160x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & 
d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), 
"+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=S8*S8 +struct MMA_64x160x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & 
d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), 
+ "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=S8*S8 +struct MMA_64x176x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + 
uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), 
+ "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=S8*S8 +struct MMA_64x176x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, 
uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + 
"r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=S8*S8 +struct MMA_64x208x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & 
d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), 
"+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=S8*S8 +struct MMA_64x208x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t 
& d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p;\n" + 
"}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=S8*S8 +struct MMA_64x224x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, 
uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + 
"+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=S8*S8 +struct MMA_64x224x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & 
d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), 
"+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=S8*S8 +struct MMA_64x240x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, 
uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const 
scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), 
"+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=S8*S8 +struct MMA_64x240x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & 
d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " 
%80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=S8*S8 +struct MMA_64x24x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=S8*S8 +struct MMA_64x24x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, 
uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=S8*S8 +struct MMA_64x48x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=S8*S8 +struct MMA_64x48x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, 
%20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=S8*S8 +struct MMA_64x80x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.s8 " + "{%0, 
%1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=S8*S8 +struct MMA_64x80x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t 
& d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=S8*S8 +struct MMA_64x112x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + 
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), 
"+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=S8*S8 +struct MMA_64x112x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, 
%10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=S8*S8 +struct MMA_64x144x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, 
uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), 
"+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=S8*S8 +struct MMA_64x144x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & 
d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=S8*S8 +struct MMA_64x160x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 
64x160x32 TN S32+=S8*S8 +struct MMA_64x160x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, 
%5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=S8*S8 +struct MMA_64x176x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, 
uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, 
%31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=S8*S8 +struct MMA_64x176x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, 
uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, 
%29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=S8*S8 +struct MMA_64x208x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t 
const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), 
"+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=S8*S8 +struct MMA_64x208x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & 
d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), 
"+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=S8*S8 +struct MMA_64x224x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t 
& d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " 
%16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), 
"+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=S8*S8 +struct MMA_64x224x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, 
uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + 
"+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=S8*S8 +struct MMA_64x240x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & 
d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), 
"+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=S8*S8 +struct MMA_64x240x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + 
uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, 
%78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x240x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=S8*U8 +struct MMA_64x24x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=S8*U8 +struct MMA_64x24x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & 
d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=S8*U8 +struct MMA_64x48x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " 
%24," + " %25," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=S8*U8 +struct MMA_64x48x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), 
"+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=S8*U8 +struct MMA_64x80x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), 
"+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=S8*U8 +struct MMA_64x80x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, 
////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x112x32 TN S32+=S8*U8
// Warpgroup-wide (128-thread) asynchronous MMA: accumulates a 64x112 S32 tile
// from S8 (A) x U8 (B) operands. "SS": both A and B are read from shared
// memory through 64-bit wgmma matrix descriptors.
struct MMA_64x112x32_S32S8U8_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];   // shared-memory matrix descriptor for A
  using BRegisters = uint64_t[1];   // shared-memory matrix descriptor for B
  using CRegisters = uint32_t[56];  // 64*112 S32 accumulators / 128 threads = 56 per thread

  // d00..d55: this thread's accumulator fragment (read-modify-write, "+r").
  // scale_D feeds the wgmma scale-d predicate p (p = scale_D != 0); Zero drops
  // the previous accumulator contribution — see PTX ISA, wgmma.mma_async.
  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    // Operand map: %0..%55 = accumulators, %56/%57 = A/B descriptors,
    // %58 = scale_D (consumed only by the setp that builds predicate p).
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %58, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.u8 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55},"
      " %56,"
      " %57,"
      " p;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55)
      :  "l"(desc_a),
         "l"(desc_b),
         "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x112x32 TN S32+=S8*U8
// Saturating variant of MMA_64x112x32_S32S8U8_SS_TN: the ".satfinite"
// qualifier clamps results to the S32 range instead of wrapping (per PTX ISA).
struct MMA_64x112x32_S32S8U8_SS_TN_SATURATE
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];   // shared-memory matrix descriptor for A
  using BRegisters = uint64_t[1];   // shared-memory matrix descriptor for B
  using CRegisters = uint32_t[56];  // 64*112 S32 accumulators / 128 threads = 56 per thread

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %58, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.u8.satfinite "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55},"
      " %56,"
      " %57,"
      " p;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55)
      :  "l"(desc_a),
         "l"(desc_b),
         "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};
////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x144x32 TN S32+=S8*U8
// Warpgroup-wide (128-thread) asynchronous MMA: accumulates a 64x144 S32 tile
// from S8 (A) x U8 (B) operands; both operands read from shared memory ("SS")
// through 64-bit wgmma matrix descriptors.
struct MMA_64x144x32_S32S8U8_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];   // shared-memory matrix descriptor for A
  using BRegisters = uint64_t[1];   // shared-memory matrix descriptor for B
  using CRegisters = uint32_t[72];  // 64*144 S32 accumulators / 128 threads = 72 per thread

  // d00..d71: per-thread accumulator fragment. scale_D builds the wgmma
  // scale-d predicate p (p = scale_D != 0) — see PTX ISA, wgmma.mma_async.
  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
      uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
      uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
      uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
      uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    // Operand map: %0..%71 = accumulators, %72/%73 = A/B descriptors, %74 = scale_D.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %74, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.u8 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71},"
      " %72,"
      " %73,"
      " p;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
        "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
        "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
        "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
        "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71)
      :  "l"(desc_a),
         "l"(desc_b),
         "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x144x32 TN S32+=S8*U8
// Saturating variant of MMA_64x144x32_S32S8U8_SS_TN: ".satfinite" clamps
// results to the S32 range instead of wrapping (per PTX ISA).
struct MMA_64x144x32_S32S8U8_SS_TN_SATURATE
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];   // shared-memory matrix descriptor for A
  using BRegisters = uint64_t[1];   // shared-memory matrix descriptor for B
  using CRegisters = uint32_t[72];  // 64*144 S32 accumulators / 128 threads = 72 per thread

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
      uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
      uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
      uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
      uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %74, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.u8.satfinite "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71},"
      " %72,"
      " %73,"
      " p;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
        "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
        "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
        "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
        "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71)
      :  "l"(desc_a),
         "l"(desc_b),
         "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};
////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x160x32 TN S32+=S8*U8
// Warpgroup-wide (128-thread) asynchronous MMA: accumulates a 64x160 S32 tile
// from S8 (A) x U8 (B) operands; both operands read from shared memory ("SS")
// through 64-bit wgmma matrix descriptors.
struct MMA_64x160x32_S32S8U8_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];   // shared-memory matrix descriptor for A
  using BRegisters = uint64_t[1];   // shared-memory matrix descriptor for B
  using CRegisters = uint32_t[80];  // 64*160 S32 accumulators / 128 threads = 80 per thread

  // d00..d79: per-thread accumulator fragment. scale_D builds the wgmma
  // scale-d predicate p (p = scale_D != 0) — see PTX ISA, wgmma.mma_async.
  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
      uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
      uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
      uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
      uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
      uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
      uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    // Operand map: %0..%79 = accumulators, %80/%81 = A/B descriptors, %82 = scale_D.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %82, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.u8 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79},"
      " %80,"
      " %81,"
      " p;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
        "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
        "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
        "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
        "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
        "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
        "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79)
      :  "l"(desc_a),
         "l"(desc_b),
         "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x160x32 TN S32+=S8*U8
// Saturating variant of MMA_64x160x32_S32S8U8_SS_TN: ".satfinite" clamps
// results to the S32 range instead of wrapping (per PTX ISA).
struct MMA_64x160x32_S32S8U8_SS_TN_SATURATE
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];   // shared-memory matrix descriptor for A
  using BRegisters = uint64_t[1];   // shared-memory matrix descriptor for B
  using CRegisters = uint32_t[80];  // 64*160 S32 accumulators / 128 threads = 80 per thread

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
      uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
      uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
      uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
      uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
      uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
      uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %82, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.u8.satfinite "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79},"
      " %80,"
      " %81,"
      " p;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
        "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
        "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
        "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
        "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
        "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
        "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79)
      :  "l"(desc_a),
         "l"(desc_b),
         "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};
////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x176x32 TN S32+=S8*U8
// Warpgroup-wide (128-thread) asynchronous MMA: accumulates a 64x176 S32 tile
// from S8 (A) x U8 (B) operands; both operands read from shared memory ("SS")
// through 64-bit wgmma matrix descriptors.
struct MMA_64x176x32_S32S8U8_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];   // shared-memory matrix descriptor for A
  using BRegisters = uint64_t[1];   // shared-memory matrix descriptor for B
  using CRegisters = uint32_t[88];  // 64*176 S32 accumulators / 128 threads = 88 per thread

  // d00..d87: per-thread accumulator fragment. scale_D builds the wgmma
  // scale-d predicate p (p = scale_D != 0) — see PTX ISA, wgmma.mma_async.
  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
      uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
      uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
      uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
      uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
      uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
      uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
      uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83,
      uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    // Operand map: %0..%87 = accumulators, %88/%89 = A/B descriptors, %90 = scale_D.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %90, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.u8 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87},"
      " %88,"
      " %89,"
      " p;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
        "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
        "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
        "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
        "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
        "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
        "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79),
        "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83),
        "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87)
      :  "l"(desc_a),
         "l"(desc_b),
         "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x176x32 TN S32+=S8*U8
// Saturating variant of MMA_64x176x32_S32S8U8_SS_TN: ".satfinite" clamps
// results to the S32 range instead of wrapping (per PTX ISA).
struct MMA_64x176x32_S32S8U8_SS_TN_SATURATE
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];   // shared-memory matrix descriptor for A
  using BRegisters = uint64_t[1];   // shared-memory matrix descriptor for B
  using CRegisters = uint32_t[88];  // 64*176 S32 accumulators / 128 threads = 88 per thread

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
      uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
      uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
      uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
      uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
      uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
      uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
      uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83,
      uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %90, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.u8.satfinite "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87},"
      " %88,"
      " %89,"
      " p;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
        "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
        "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
        "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
        "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
        "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
        "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79),
        "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83),
        "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87)
      :  "l"(desc_a),
         "l"(desc_b),
         "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};
////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x208x32 TN S32+=S8*U8
// Warpgroup-wide (128-thread) asynchronous MMA: accumulates a 64x208 S32 tile
// from S8 (A) x U8 (B) operands; both operands read from shared memory ("SS")
// through 64-bit wgmma matrix descriptors. Register names use three digits
// (d000..d103) because the fragment exceeds 100 registers.
struct MMA_64x208x32_S32S8U8_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];   // shared-memory matrix descriptor for A
  using BRegisters = uint64_t[1];   // shared-memory matrix descriptor for B
  using CRegisters = uint32_t[104]; // 64*208 S32 accumulators / 128 threads = 104 per thread

  // d000..d103: per-thread accumulator fragment. scale_D builds the wgmma
  // scale-d predicate p (p = scale_D != 0) — see PTX ISA, wgmma.mma_async.
  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
      uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
      uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
      uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
      uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
      uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
      uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
      uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
      uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
      uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
      uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
      uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
      uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
      uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
      uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
      uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
      uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
      uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
      uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
      uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
      uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
      uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
      uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
      uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
      uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
      uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    // Operand map: %0..%103 = accumulators, %104/%105 = A/B descriptors, %106 = scale_D.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %106, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.u8 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103},"
      " %104,"
      " %105,"
      " p;\n"
    "}\n"
      : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
        "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
        "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
        "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
        "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
        "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
        "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
        "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
        "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
        "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
        "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
        "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
        "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
        "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
        "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
        "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
        "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
        "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
        "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
        "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
        "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
        "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
        "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
        "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
        "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
        "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103)
      :  "l"(desc_a),
         "l"(desc_b),
         "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x208x32 TN S32+=S8*U8
// Saturating variant of MMA_64x208x32_S32S8U8_SS_TN: ".satfinite" clamps
// results to the S32 range instead of wrapping (per PTX ISA).
struct MMA_64x208x32_S32S8U8_SS_TN_SATURATE
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];   // shared-memory matrix descriptor for A
  using BRegisters = uint64_t[1];   // shared-memory matrix descriptor for B
  using CRegisters = uint32_t[104]; // 64*208 S32 accumulators / 128 threads = 104 per thread

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
      uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
      uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
      uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
      uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
      uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
      uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
      uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
      uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
      uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
      uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
      uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
      uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
      uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
      uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
      uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
      uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
      uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
      uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
      uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
      uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
      uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
      uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
      uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
      uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
      uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %106, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.u8.satfinite "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103},"
      " %104,"
      " %105,"
      " p;\n"
    "}\n"
      : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
        "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
        "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
        "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
        "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
        "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
        "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
        "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
        "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
        "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
        "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
        "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
        "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
        "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
        "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
        "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
        "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
        "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
        "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
        "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
        "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
        "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
        "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
        "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
        "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
        "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103)
      :  "l"(desc_a),
         "l"(desc_b),
         "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};
uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + 
uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), 
"+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=S8*U8 +struct MMA_64x224x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, 
uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," 
+ " %113," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=S8*U8 +struct MMA_64x240x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& 
desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + 
uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), 
"+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=S8*U8 +struct MMA_64x240x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & 
d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, 
%54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); 
+#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=S8*U8 +struct MMA_64x24x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=S8*U8 +struct MMA_64x24x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& 
a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=S8*U8 +struct MMA_64x48x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=S8*U8 +struct MMA_64x48x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=S8*U8 +struct MMA_64x80x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=S8*U8 +struct MMA_64x80x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & 
d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=S8*U8 +struct MMA_64x112x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + 
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + 
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=S8*U8 +struct MMA_64x112x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, 
desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=S8*U8 +struct MMA_64x144x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, 
uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), 
"+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=S8*U8 +struct MMA_64x144x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & 
d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + 
"+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=S8*U8 +struct MMA_64x160x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, 
uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8U8_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=S8*U8 +struct MMA_64x160x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + 
asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=S8*U8 +struct MMA_64x176x32_S32S8U8_RS_TN +{ + using DRegisters = void; + 
using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.u8 " + 
"{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=S8*U8 +struct MMA_64x176x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + 
using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=S8*U8 
+struct MMA_64x208x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & 
d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + 
"+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=S8*U8 +struct MMA_64x208x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & 
d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), 
"+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=S8*U8 +struct MMA_64x224x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, 
uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, 
desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + 
"+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=S8*U8 +struct MMA_64x224x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + 
uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), 
"+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=S8*U8 +struct MMA_64x240x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + 
uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & 
d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), 
"+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=S8*U8 +struct MMA_64x240x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, 
uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, 
%38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), 
"+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=U8*S8 +struct MMA_64x24x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=U8*S8 +struct MMA_64x24x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, 
+ uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=U8*S8 +struct MMA_64x48x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 
p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=U8*S8 +struct MMA_64x48x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p;\n" + "}\n" + : "+r"(d00), 
"+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=U8*S8 +struct MMA_64x80x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " 
%32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=U8*S8 +struct MMA_64x80x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + 
".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=U8*S8 +struct MMA_64x112x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, 
uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=U8*S8 +struct 
MMA_64x112x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), 
"+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=U8*S8 +struct MMA_64x144x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & 
d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8S8_SS_TN 
without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=U8*S8 +struct MMA_64x144x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + 
" %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=U8*S8 +struct MMA_64x160x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, 
uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + 
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=U8*S8 +struct MMA_64x160x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, 
+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), 
"+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=U8*S8 +struct MMA_64x176x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, 
uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), 
"+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=U8*S8 +struct MMA_64x176x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, 
uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), 
"+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=U8*S8 +struct MMA_64x208x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + 
uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), 
"+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=U8*S8 +struct MMA_64x208x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, 
uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, 
%46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 
64x224x32 TN S32+=U8*S8 +struct MMA_64x224x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, 
uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), 
"+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=U8*S8 +struct MMA_64x224x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t 
& d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, 
%87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=U8*S8 +struct MMA_64x240x32_S32U8S8_SS_TN +{ + using DRegisters 
/* Patch hunk (re-wrapped): body of MMA_64x240x32_S32U8S8_SS_TN and its
 * _SATURATE twin — PTX "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.s8
 * [.satfinite]", S32 += U8 * S8, A/B as 64-bit smem descriptors. n=240 uses
 * 120 u32 accumulators (CRegisters = uint32_t[120], d000..d119), so the
 * predicate source scale_D is asm operand %122 (120 accumulators + desc_a
 * %120 + desc_b %121). Also adds the register-sourced ("RS") 24-wide atom
 * MMA_64x24x32_S32U8S8_RS_TN: A comes from four u32 registers a00..a03
 * ("r" inputs), B via desc_b; 12 accumulators; synclog records reg/smem. */
= void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & 
d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + 
"+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=U8*S8 +struct MMA_64x240x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, 
uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + 
" %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), 
"+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=U8*S8 +struct MMA_64x24x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=U8*S8 +struct MMA_64x24x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters 
/* Patch hunk (re-wrapped): completes MMA_64x24x32_S32U8S8_RS_TN_SATURATE and
 * adds the 48-wide pair MMA_64x48x32_S32U8S8_RS_TN / _SATURATE — PTX
 * "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.s8[.satfinite]",
 * S32 += U8 * S8. "RS": the A fragment lives in four u32 registers
 * (a00..a03, "r" input constraints, synclog reg/smem), B is a 64-bit smem
 * descriptor. n=48 needs 24 u32 accumulators (d00..d23), so A occupies asm
 * operands %24-%27, desc_b is %28 and the scale_D predicate source is %29. */
= uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=U8*S8 +struct MMA_64x48x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & 
d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=U8*S8 +struct MMA_64x48x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=U8*S8 +struct MMA_64x80x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & 
/* Patch hunk (re-wrapped): completes MMA_64x80x32_S32U8S8_RS_TN, adds its
 * _SATURATE twin, the complete MMA_64x112x32_S32U8S8_RS_TN, and the opening
 * of MMA_64x112x32_S32U8S8_RS_TN_SATURATE (whose asm body continues beyond
 * this hunk). All are register-sourced-A ("RS") int8 wgmma atoms,
 * "wgmma.mma_async.sync.aligned.m64nNNk32.s32.u8.s8[.satfinite]":
 * n=80 uses 40 accumulators (A regs at %40-%43, desc_b %44, scale_D %45);
 * n=112 uses 56 accumulators (A regs at %56-%59, desc_b %60, scale_D %61).
 * Guarded by CUTE_ARCH_MMA_SM90A_ENABLED as in the sibling atoms. */
d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=U8*S8 +struct MMA_64x80x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, 
uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=U8*S8 +struct MMA_64x112x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t 
const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), 
"+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=U8*S8 +struct MMA_64x112x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + 
{ +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=U8*S8 +struct MMA_64x144x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, 
uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), 
"+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=U8*S8 +struct MMA_64x144x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & 
d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), 
"+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=U8*S8 +struct MMA_64x160x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & 
d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + 
"r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=U8*S8 +struct MMA_64x160x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x176x32 TN S32+=U8*S8
// Warpgroup-wide MMA for a 64x176x32 tile: s32 accumulator += u8(A) * s8(B).
// "RS" = A sourced from registers (a00..a03), B via the 64-bit matrix
// descriptor desc_b (smem, per synclog_emit_wgmma_reg_smem). 88 accumulator
// registers (d00..d87); scale_D != Zero keeps the prior accumulator via
// predicate p. SM90a only.
struct MMA_64x176x32_S32U8S8_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[88];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
      uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
      uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
      uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
      uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
      uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
      uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
      uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83,
      uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    // %0-%87 accumulator, {%88-%91} A, %92 B descriptor, %93 -> predicate p.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %93, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.s8 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87},"
      "{%88, %89, %90, %91},"
      " %92,"
      " p;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
        "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
        "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
        "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
        "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
        "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
        "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79),
        "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83),
        "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87)
      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
        "l"(desc_b),
        "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x176x32 TN S32+=U8*S8
// Saturating (".satfinite") variant of MMA_64x176x32_S32U8S8_RS_TN.
struct MMA_64x176x32_S32U8S8_RS_TN_SATURATE
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[88];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
      uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
      uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
      uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
      uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
      uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
      uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
      uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83,
      uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %93, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.s8.satfinite "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87},"
      "{%88, %89, %90, %91},"
      " %92,"
      " p;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
        "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
        "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
        "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
        "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
        "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
        "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79),
        "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83),
        "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87)
      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
        "l"(desc_b),
        "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x208x32 TN S32+=U8*S8
// 208-column variant: 104 accumulator registers, hence the 3-digit operand
// naming (a000..a003, d000..d103) used from this shape onward.
struct MMA_64x208x32_S32U8S8_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[104];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
      uint64_t const& desc_b,
      uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
      uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
      uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
      uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
      uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
      uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
      uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
      uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
      uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
      uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
      uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
      uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
      uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
      uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
      uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
      uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
      uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
      uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
      uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
      uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
      uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
      uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
      uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
      uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
      uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
      uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    // %0-%103 accumulator, {%104-%107} A, %108 B descriptor, %109 -> predicate p.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %109, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.s8 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103},"
      "{%104, %105, %106, %107},"
      " %108,"
      " p;\n"
    "}\n"
      : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
        "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
        "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
        "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
        "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
        "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
        "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
        "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
        "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
        "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
        "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
        "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
        "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
        "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
        "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
        "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
        "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
        "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
        "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
        "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
        "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
        "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
        "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
        "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
        "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
        "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103)
      : "r"(a000), "r"(a001), "r"(a002), "r"(a003),
        "l"(desc_b),
        "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x208x32 TN S32+=U8*S8
// Saturating (".satfinite") variant of MMA_64x208x32_S32U8S8_RS_TN.
struct MMA_64x208x32_S32U8S8_RS_TN_SATURATE
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[104];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
      uint64_t const& desc_b,
      uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
      uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
      uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
      uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
      uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
      uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
      uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
      uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
      uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
      uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
      uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
      uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
      uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
      uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
      uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
      uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
      uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
      uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
      uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
      uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
      uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
      uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
      uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
      uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
      uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
      uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %109, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.s8.satfinite "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103},"
      "{%104, %105, %106, %107},"
      " %108,"
      " p;\n"
    "}\n"
      : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
        "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
        "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
        "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
        "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
        "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
        "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
        "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
        "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
        "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
        "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
        "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
        "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
        "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
        "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
        "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
        "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
        "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
        "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
        "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
        "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
        "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
        "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
        "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
        "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
        "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103)
      : "r"(a000), "r"(a001), "r"(a002), "r"(a003),
        "l"(desc_b),
        "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x224x32 TN S32+=U8*S8
struct MMA_64x224x32_S32U8S8_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[112];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a000, uint32_t const& a001,
uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, 
uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), 
"+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=U8*S8 +struct MMA_64x224x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t 
& d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, 
%105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=U8*S8 +struct MMA_64x240x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = 
uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & 
d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), 
"+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=U8*S8 +struct MMA_64x240x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, 
uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, 
" + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), 
"+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=U8*U8 +struct MMA_64x24x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=U8*U8 +struct MMA_64x24x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters 
= void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=U8*U8 +struct MMA_64x48x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { 
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=U8*U8 +struct MMA_64x48x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.u8.satfinite " + "{%0, 
%1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=U8*U8 +struct MMA_64x80x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.u8 " + 
"{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=U8*U8 +struct MMA_64x80x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const 
scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=U8*U8 +struct MMA_64x112x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & 
d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8U8_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=U8*U8 +struct MMA_64x112x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," 
+ " %56," + " %57," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=U8*U8 +struct MMA_64x144x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, 
uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), 
"+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=U8*U8 +struct MMA_64x144x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=U8*U8 +struct MMA_64x160x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + 
CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, 
%52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=U8*U8 +struct MMA_64x160x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & 
d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), 
"+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=U8*U8 +struct MMA_64x176x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, 
uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), 
"+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=U8*U8 +struct MMA_64x176x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & 
d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + 
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=U8*U8 +struct MMA_64x208x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & 
d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), 
"+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=U8*U8 +struct MMA_64x208x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & 
d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.u8.satfinite " + 
"{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=U8*U8 +struct MMA_64x224x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & 
d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), 
"+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=U8*U8 +struct MMA_64x224x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, 
uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, 
%42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x224x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=U8*U8 +struct MMA_64x240x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, 
uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + 
"+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=U8*U8 +struct MMA_64x240x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, 
uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), 
"+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=U8*U8 +struct MMA_64x24x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to 
use MMA_64x24x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=U8*U8 +struct MMA_64x24x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=U8*U8 +struct MMA_64x48x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t 
& d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=U8*U8 +struct MMA_64x48x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & 
d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=U8*U8 +struct MMA_64x80x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + 
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=U8*U8 +struct MMA_64x80x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, 
uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=U8*U8 +struct MMA_64x112x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, 
%58, %59}," + " %60," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=U8*U8 +struct MMA_64x112x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & 
d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=U8*U8 +struct 
MMA_64x144x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, 
%38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=U8*U8 +struct MMA_64x144x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, 
uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + 
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=U8*U8 +struct MMA_64x160x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, 
uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), 
"+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=U8*U8 +struct MMA_64x160x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, 
uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), 
"+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=U8*U8 +struct MMA_64x176x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, 
uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + 
"+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=U8*U8 +struct MMA_64x176x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & 
d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), 
"+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=U8*U8 +struct MMA_64x208x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, 
uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), 
"+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=U8*U8 +struct MMA_64x208x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, 
uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, 
%62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=U8*U8 +struct 
MMA_64x224x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + 
uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), 
"+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=U8*U8 +struct MMA_64x224x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, 
uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, 
%62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=U8*U8 +struct MMA_64x240x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, 
+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), 
"+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=U8*U8 +struct MMA_64x240x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, 
uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), 
"+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F16E4M3E4M3_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & 
d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," 
+ " p, %18, %19;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN 
F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & 
d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, 
%11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E4M3E4M3_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 
TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, 
+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + 
" %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, 
%20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut 
const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + 
"{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, 
desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, 
+ float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& 
desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + 
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & 
d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct 
MMA_64x88x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), 
"+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), 
"+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), 
"+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, 
" + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, 
+ float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F16+=E4M3*E4M3 +template < 
+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 
GMMA 64x112x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E4M3E4M3_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), 
"+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, 
float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One +> +struct MMA_64x120x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 
GMMA 64x120x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, 
%44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float 
& d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), 
"+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " 
%32, %33}," + "{%34, %35, %36, %37}," + " %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & 
d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %70, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " p, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %73, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + 
" %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " p, %74, %75;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & 
d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = 
uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } 
+}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, 
%39, "
      // (tail of MMA_64x144x32_F32E4M3E4M3_SS_TN::fma -- the definition begins above this chunk)
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71},"
      " %72,"
      " %73,"
      " p, %75, %76;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
        "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
        "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
        "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
        "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
        "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x144x32 TN F32+=E4M3*E4M3
// RS: operand A is held in registers (four 32-bit words); B is read from shared memory via descriptor.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x144x32_F32E4M3E4M3_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using BRegisters = uint64_t[1];
  using CRegisters = float[72];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      float & d28, float & d29, float & d30, float & d31,
      float & d32, float & d33, float & d34, float & d35,
      float & d36, float & d37, float & d38, float & d39,
      float & d40, float & d41, float & d42, float & d43,
      float & d44, float & d45, float & d46, float & d47,
      float & d48, float & d49, float & d50, float & d51,
      float & d52, float & d53, float & d54, float & d55,
      float & d56, float & d57, float & d58, float & d59,
      float & d60, float & d61, float & d62, float & d63,
      float & d64, float & d65, float & d66, float & d67,
      float & d68, float & d69, float & d70, float & d71,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    // p = (scale_D != 0), forwarded as the wgmma scale-d predicate operand.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %77, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n144k32.f32.e4m3.e4m3 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71},"
      "{%72, %73, %74, %75},"
      " %76,"
      " p, %78, %79;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
        "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
        "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
        "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
        "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
        "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71)
      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
        "l"(desc_b),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x152x32 TN F16+=E4M3*E4M3
// SS: both operands are read from shared memory via 64-bit matrix descriptors.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x152x32_F16E4M3E4M3_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[38];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    // p = (scale_D != 0), forwarded as the wgmma scale-d predicate operand.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %40, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n152k32.f16.e4m3.e4m3 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37},"
      " %38,"
      " %39,"
      " p, %41, %42;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x152x32 TN F16+=E4M3*E4M3
// RS: operand A is held in registers (four 32-bit words); B is read from shared memory via descriptor.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x152x32_F16E4M3E4M3_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[38];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    // p = (scale_D != 0), forwarded as the wgmma scale-d predicate operand.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %43, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n152k32.f16.e4m3.e4m3 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37},"
      "{%38, %39, %40, %41},"
      " %42,"
      " p, %44, %45;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37)
      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
        "l"(desc_b),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x152x32 TN F32+=E4M3*E4M3
// SS: both operands are read from shared memory via 64-bit matrix descriptors.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x152x32_F32E4M3E4M3_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[76];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      float & d28, float & d29, float & d30, float & d31,
      float & d32, float & d33, float & d34, float & d35,
      float & d36, float & d37, float & d38, float & d39,
      float & d40, float & d41, float & d42, float & d43,
      float & d44, float & d45, float & d46, float & d47,
      float & d48, float & d49, float & d50, float & d51,
      float & d52, float & d53, float & d54, float & d55,
      float & d56, float & d57, float & d58, float & d59,
      float & d60, float & d61, float & d62, float & d63,
      float & d64, float & d65, float & d66, float & d67,
      float & d68, float & d69, float & d70, float & d71,
      float & d72, float & d73, float & d74, float & d75,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    // p = (scale_D != 0), forwarded as the wgmma scale-d predicate operand.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %78, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n152k32.f32.e4m3.e4m3 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75},"
      " %76,"
      " %77,"
      " p, %79, %80;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
        "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
        "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
        "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
        "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
        "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
        "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x152x32 TN F32+=E4M3*E4M3
// RS: operand A is held in registers (four 32-bit words); B is read from shared memory via descriptor.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x152x32_F32E4M3E4M3_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using BRegisters = uint64_t[1];
  using CRegisters = float[76];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      float & d28, float & d29, float & d30, float & d31,
      float & d32, float & d33, float & d34, float & d35,
      float & d36, float & d37, float & d38, float & d39,
      float & d40, float & d41, float & d42, float & d43,
      float & d44, float & d45, float & d46, float & d47,
      float & d48, float & d49, float & d50, float & d51,
      float & d52, float & d53, float & d54, float & d55,
      float & d56, float & d57, float & d58, float & d59,
      float & d60, float & d61, float & d62, float & d63,
      float & d64, float & d65, float & d66, float & d67,
      float & d68, float & d69, float & d70, float & d71,
      float & d72, float & d73, float & d74, float & d75,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    // p = (scale_D != 0), forwarded as the wgmma scale-d predicate operand.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %81, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n152k32.f32.e4m3.e4m3 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75},"
      "{%76, %77, %78, %79},"
      " %80,"
      " p, %82, %83;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
        "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
        "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
        "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
        "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
        "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
        "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75)
      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
        "l"(desc_b),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x160x32 TN F16+=E4M3*E4M3
// SS: both operands are read from shared memory via 64-bit matrix descriptors.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x160x32_F16E4M3E4M3_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[40];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    // p = (scale_D != 0), forwarded as the wgmma scale-d predicate operand.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %42, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n160k32.f16.e4m3.e4m3 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39},"
      " %40,"
      " %41,"
      " p, %43, %44;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x160x32 TN F16+=E4M3*E4M3
// RS: operand A is held in registers (four 32-bit words); B is read from shared memory via descriptor.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x160x32_F16E4M3E4M3_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[40];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    // p = (scale_D != 0), forwarded as the wgmma scale-d predicate operand.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %45, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n160k32.f16.e4m3.e4m3 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39},"
      "{%40, %41, %42, %43},"
      " %44,"
      " p, %46, %47;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39)
      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
        "l"(desc_b),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x160x32 TN F32+=E4M3*E4M3
// SS: both operands are read from shared memory via 64-bit matrix descriptors.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x160x32_F32E4M3E4M3_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[80];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      float & d28, float & d29, float & d30, float & d31,
      float & d32, float & d33, float & d34, float & d35,
      float & d36, float & d37, float & d38, float & d39,
      float & d40, float & d41, float & d42, float & d43,
      float & d44, float & d45, float & d46, float & d47,
      float & d48, float & d49, float & d50, float & d51,
      float & d52, float & d53, float & d54, float & d55,
      float & d56, float & d57, float & d58, float & d59,
      float & d60, float & d61, float & d62, float & d63,
      float & d64, float & d65, float & d66, float & d67,
      float & d68, float & d69, float & d70, float & d71,
      float & d72, float & d73, float & d74, float & d75,
      float & d76, float & d77, float & d78, float & d79,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    // p = (scale_D != 0), forwarded as the wgmma scale-d predicate operand.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %82, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e4m3 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79},"
      " %80,"
      " %81,"
      " p, %83, %84;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
        "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
        "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
        "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
        "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
        "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
        "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75),
        "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x160x32 TN F32+=E4M3*E4M3
// RS: operand A is held in registers (four 32-bit words); B is read from shared memory via descriptor.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x160x32_F32E4M3E4M3_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using BRegisters = uint64_t[1];
  using CRegisters = float[80];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      float & d28, float & d29, float & d30, float & d31,
      float & d32, float & d33, float & d34, float & d35,
      float & d36, float & d37, float & d38, float & d39,
      float & d40, float & d41, float & d42, float & d43,
      float & d44, float & d45, float & d46, float & d47,
      float & d48, float & d49, float & d50, float & d51,
      float & d52, float & d53, float & d54, float & d55,
      float & d56, float & d57, float & d58, float & d59,
      float & d60, float & d61, float & d62, float & d63,
      float & d64, float & d65, float & d66, float & d67,
      float & d68, float & d69, float & d70, float & d71,
      float & d72, float & d73, float & d74, float & d75,
      float & d76, float & d77, float & d78, float & d79,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    // p = (scale_D != 0), forwarded as the wgmma scale-d predicate operand.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %85, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e4m3 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79},"
      "{%80, %81, %82, %83},"
      " %84,"
      " p, %86, %87;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
        "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
        "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
        "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
        "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
        "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
        "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75),
        "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79)
      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
        "l"(desc_b),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x168x32 TN F16+=E4M3*E4M3
// SS: both operands are read from shared memory via 64-bit matrix descriptors.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x168x32_F16E4M3E4M3_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[42];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    // p = (scale_D != 0), forwarded as the wgmma scale-d predicate operand.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %44, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n168k32.f16.e4m3.e4m3 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41},"
      " %42,"
      " %43,"
      " p, %45, %46;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x168x32 TN F16+=E4M3*E4M3
// RS: operand A is held in registers (four 32-bit words); B is read from shared memory via descriptor.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x168x32_F16E4M3E4M3_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[42];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    // p = (scale_D != 0), forwarded as the wgmma scale-d predicate operand.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %47, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n168k32.f16.e4m3.e4m3 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41},"
      "{%42, %43, %44, %45},"
      " %46,"
      " p, %48, %49;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41)
      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
        "l"(desc_b),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x168x32 TN F32+=E4M3*E4M3
// SS: both operands are read from shared memory via 64-bit matrix descriptors.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x168x32_F32E4M3E4M3_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[84];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      float & d28, float & d29, float & d30, float & d31,
      float & d32, float & d33, float & d34, float & d35,
      float & d36, float & d37, float & d38, float & d39,
      float & d40, float & d41, float & d42, float & d43,
      float & d44, float & d45, float & d46, float & d47,
      float & d48, float & d49, float & d50, float & d51,
      float & d52, float & d53, float & d54, float & d55,
      float & d56, float & d57, float & d58, float & d59,
      float & d60, float & d61, float & d62, float & d63,
      float & d64, float & d65, float & d66, float & d67,
      float & d68, float & d69, float & d70, float & d71,
      float & d72, float & d73, float & d74, float & d75,
      float & d76, float & d77, float & d78, float & d79,
      float & d80, float & d81, float & d82, float & d83,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    // p = (scale_D != 0), forwarded as the wgmma scale-d predicate operand.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %86, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n168k32.f32.e4m3.e4m3 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83},"
      " %84,"
      " %85,"
      " p, %87, %88;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
        "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
        "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
        "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
        "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
        "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
        "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75),
        "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
        "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x168x32 TN F32+=E4M3*E4M3
// RS: operand A is held in registers (four 32-bit words); B is read from shared memory via descriptor.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x168x32_F32E4M3E4M3_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using BRegisters = uint64_t[1];
  using CRegisters = float[84];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      float & d28, float & d29, float & d30, float & d31,
      float & d32, float & d33, float & d34, float & d35,
      float & d36, float & d37, float & d38, float & d39,
      float & d40, float & d41, float & d42, float & d43,
      float & d44, float & d45, float & d46, float & d47,
      float & d48, float & d49, float & d50, float & d51,
      float & d52, float & d53, float & d54, float & d55,
      float & d56, float & d57, float & d58, float & d59,
      float & d60, float & d61, float & d62, float & d63,
      float & d64, float & d65, float & d66, float & d67,
      float & d68, float & d69, float & d70, float & d71,
      float & d72, float & d73, float & d74, float & d75,
      float & d76, float & d77, float & d78, float & d79,
      float & d80, float & d81, float & d82, float & d83,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    // p = (scale_D != 0), forwarded as the wgmma scale-d predicate operand.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %89, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n168k32.f32.e4m3.e4m3 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83},"
      "{%84, %85, %86, %87},"
      " %88,"
      " p, %90, %91;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
        "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
        "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
        "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
        "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
        "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
        "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75),
        "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
        "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83)
      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
        "l"(desc_b),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x176x32 TN F16+=E4M3*E4M3
// SS: both operands are read from shared memory via 64-bit matrix descriptors.
// (head of MMA_64x176x32_F16E4M3E4M3_SS_TN -- the definition continues below this chunk)
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x176x32_F16E4M3E4M3_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[44];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01,
uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E4M3E4M3_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + 
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & 
d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), 
"+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, 
float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p, %94, %95;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), 
"+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n184k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + " %46," + " %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & 
d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 
64x184x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %94, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, 
%12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " p, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %97, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, 
%15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " p, %98, %99;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F16+=E4M3*E4M3 +template < + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), 
+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + 
uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& 
desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %102, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, 
%38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " p, %103, %104;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & 
d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %105, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " p, %106, %107;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), 
"+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & 
d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), 
"+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & 
d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), 
"+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & 
d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p, %110, %111;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + 
"+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, 
uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + 
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " p, %60, %61;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & 
d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %110, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, 
%30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " p, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, 
float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %113, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " p, %114, %115;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), 
"+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, 
+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 
64x224x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, 
%58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + 
float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + 
" %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> 
+struct MMA_64x224x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p, %118, %119;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), 
"+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, 
uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct 
MMA_64x232x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), 
"+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & 
d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %118, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, 
%105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " p, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & 
d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %121, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " p, %122, %123;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), 
"+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, 
uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, 
%27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, 
float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, 
%41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), 
"+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, 
float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p, %126, %127;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), 
"+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t 
& d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), 
"+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t 
& d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, 
float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %126, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " p, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + 
"+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + 
float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %129, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, 
%25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " p, %130, %131;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), 
"+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn 
scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + 
asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = 
uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f32.e4m3.e5m2 " 
+ "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn 
scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + 
using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + 
fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e5m2 " + "{%0, %1, 
%2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e5m2 " + "{%0, 
%1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " 
%16, %17}," + " %18," + " %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), 
"+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + 
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, 
%25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, 
%9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + 
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, 
%36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, 
uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut 
const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & 
d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + 
float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x104x32 TN F16+=E4M3*E5M2
+// Warpgroup MMA wrapper: issues a single wgmma.mma_async computing a 64x104 f16
+// accumulator tile from e4m3 (A) and e5m2 (B) fragments with K = 32.
+// "_SS": both A and B are supplied as 64-bit matrix descriptors (desc_a/desc_b;
+// the synclog call below records them as smem/smem). "_TN": transpose/normal
+// operand layout -- NOTE(review): layout semantics inferred from the suffix,
+// confirm against the corresponding mma_traits specialization.
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x104x32_F16E4M3E5M2_SS_TN
+{
+  using DRegisters = void;            // D aliases C: accumulated in place through CRegisters
+  using ARegisters = uint64_t[1];     // matrix descriptor for A
+  using BRegisters = uint64_t[1];     // matrix descriptor for B
+  using CRegisters = uint32_t[26];    // 26 u32 regs holding the f16 accumulator (2 halves per reg)
+
+  // Performs D = (scale_D ? D : 0) + (scaleA * A) * (scaleB * B); scaleA/scaleB are
+  // compile-time template immediates, scale_D is a runtime predicate input.
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
+      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
+      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
+      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
+      uint32_t & d24, uint32_t & d25,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    // Positional asm operands: %0-%25 accumulator (read-write "+r"), %26 desc_a,
+    // %27 desc_b, %28 scale_D, %29/%30 scaleA/scaleB immediates.
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %28, 0;\n"    // p = (scale_D != 0): wgmma scale-d predicate
+      "wgmma.mma_async.sync.aligned.m64n104k32.f16.e4m3.e5m2 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25},"
+      " %26,"    // A descriptor
+      " %27,"    // B descriptor
+      " p, %29, %30;\n"
+    "}\n"
+      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
+        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
+        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
+        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
+        "+r"(d24), "+r"(d25)
+      : "l"(desc_a),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x104x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + 
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float 
& d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = 
uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using 
BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn 
scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + 
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& 
desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using 
BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F32+=E4M3*E5M2 +template 
< + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + 
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float 
& d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn 
scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x136x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), 
+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %70, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " p, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = 
uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %73, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " p, %74, %75;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), 
"+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & 
d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, 
uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using 
ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), 
"+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & 
d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p, %78, %79;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), 
"+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, 
%30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const 
scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float 
& d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %78, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " p, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + 
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, 
float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %81, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " p, %82, %83;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), 
"r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+r"(d00), 
"+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, 
float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), 
"+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float 
& d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p, %86, %87;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), 
+ "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " 
+ " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & 
d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = 
float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %86, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, 
%81, %82, %83}," + " %84," + " %85," + " p, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & 
d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %89, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " p, %90, %91;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), 
+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t 
& d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using 
BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), 
"+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float 
& d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p, %94, %95;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E4M3E5M2_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + " %46," + " %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), 
"+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + 
uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, 
float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %94, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, 
" + " %88, %89, %90, %91}," + " %92," + " %93," + " p, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + 
float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %97, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, 
%91}," + "{%92, %93, %94, %95}," + " %96," + " p, %98, %99;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, 
uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, 
%18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, 
float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %102, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " p, %103, %104;\n" + "}\n" + : "+f"(d000), 
"+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& 
a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %105, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, 
%36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " p, %106, %107;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, 
%56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & 
d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, 
desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + 
"+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, 
float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p, %110, %111;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), 
"+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + 
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 
64x216x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " p, %60, %61;\n" 
+ "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & 
d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %110, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " p, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), 
"+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static 
void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %113, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f32.e4m3.e5m2 " + "{%0, %1, 
%2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " p, %114, %115;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), 
"+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n224k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t 
& d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), 
"+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & 
d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), 
"+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, 
float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, 
%58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p, %118, %119;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x224x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, 
%25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & 
d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), 
"+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, 
float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %118, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " p, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + 
"+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & 
d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %121, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, 
%37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " p, %122, %123;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : 
"r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + 
"setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, 
uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + 
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, 
float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + 
"+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F32E4M3E5M2_RS_TN +{ + using DRegisters 
= void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & 
d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p, %126, %127;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), 
"+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, 
uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F16E4M3E5M2_SS_TN 
without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, 
%11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, 
float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %126, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " 
%8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " p, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + 
"+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, 
float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %129, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " p, %130, 
%131;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F16+=E5M2*E4M3 +template < + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + 
"setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x24x32 TN F32+=E5M2*E4M3
+// RS variant: the A operand is supplied from registers (four 32-bit fragments
+// a00..a03); B is read from shared memory through the 64-bit matrix descriptor
+// desc_b.  The accumulator D stays in registers ("+f" read-write constraints).
+// scale_D is lowered to predicate p via setp: when scale_D == GMMA::ScaleOut::Zero
+// the wgmma instruction overwrites D instead of accumulating (see PTX ISA,
+// wgmma.mma_async scale-d operand).  scaleA/scaleB are compile-time immediates
+// ("n" constraints) taken from the template parameters.
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x24x32_F32E5M2E4M3_RS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint32_t[4];
+  using BRegisters = uint64_t[1];
+  using CRegisters = float[12];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
+      uint64_t const& desc_b,
+      float & d00, float & d01, float & d02, float & d03,
+      float & d04, float & d05, float & d06, float & d07,
+      float & d08, float & d09, float & d10, float & d11,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %17, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n24k32.f32.e5m2.e4m3 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11},"
+      "{%12, %13, %14, %15},"
+      " %16,"
+      " p, %18, %19;\n"
+    "}\n"
+      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
+        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
+        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11)
+      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x40x32 TN F16+=E5M2*E4M3
+// SS variant: both A and B are read from shared memory via descriptors
+// desc_a/desc_b.  F16 accumulators pack two halves per 32-bit register,
+// hence uint32_t[10] for a 64x40 tile.
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x40x32_F16E5M2E4M3_SS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[10];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+      uint32_t & d08, uint32_t & d09,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %12, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n40k32.f16.e5m2.e4m3 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9},"
+      " %10,"
+      " %11,"
+      " p, %13, %14;\n"
+    "}\n"
+      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+        "+r"(d08), "+r"(d09)
+      : "l"(desc_a),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x40x32 TN F16+=E5M2*E4M3
+// RS variant: A in registers a00..a03, B via shared-memory descriptor desc_b.
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x40x32_F16E5M2E4M3_RS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint32_t[4];
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[10];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
+      uint64_t const& desc_b,
+      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+      uint32_t & d08, uint32_t & d09,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %15, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n40k32.f16.e5m2.e4m3 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9},"
+      "{%10, %11, %12, %13},"
+      " %14,"
+      " p, %16, %17;\n"
+    "}\n"
+      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+        "+r"(d08), "+r"(d09)
+      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x40x32 TN F32+=E5M2*E4M3
+// SS variant; F32 accumulation uses one register per value: float[20] for 64x40.
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x40x32_F32E5M2E4M3_SS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = float[20];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      float & d00, float & d01, float & d02, float & d03,
+      float & d04, float & d05, float & d06, float & d07,
+      float & d08, float & d09, float & d10, float & d11,
+      float & d12, float & d13, float & d14, float & d15,
+      float & d16, float & d17, float & d18, float & d19,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %22, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n40k32.f32.e5m2.e4m3 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19},"
+      " %20,"
+      " %21,"
+      " p, %23, %24;\n"
+    "}\n"
+      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
+        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
+        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
+        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
+        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19)
+      : "l"(desc_a),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x40x32 TN F32+=E5M2*E4M3
+// RS variant: A in registers a00..a03, B via shared-memory descriptor desc_b.
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x40x32_F32E5M2E4M3_RS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint32_t[4];
+  using BRegisters = uint64_t[1];
+  using CRegisters = float[20];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
+      uint64_t const& desc_b,
+      float & d00, float & d01, float & d02, float & d03,
+      float & d04, float & d05, float & d06, float & d07,
+      float & d08, float & d09, float & d10, float & d11,
+      float & d12, float & d13, float & d14, float & d15,
+      float & d16, float & d17, float & d18, float & d19,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %25, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n40k32.f32.e5m2.e4m3 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19},"
+      "{%20, %21, %22, %23},"
+      " %24,"
+      " p, %26, %27;\n"
+    "}\n"
+      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
+        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
+        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
+        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
+        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19)
+      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x48x32 TN F16+=E5M2*E4M3
+// SS variant: both operands via shared-memory descriptors.
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x48x32_F16E5M2E4M3_SS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[12];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %14, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n48k32.f16.e5m2.e4m3 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11},"
+      " %12,"
+      " %13,"
+      " p, %15, %16;\n"
+    "}\n"
+      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11)
+      : "l"(desc_a),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x48x32 TN F16+=E5M2*E4M3
+// RS variant: A in registers a00..a03, B via shared-memory descriptor desc_b.
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x48x32_F16E5M2E4M3_RS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint32_t[4];
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[12];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
+      uint64_t const& desc_b,
+      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %17, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n48k32.f16.e5m2.e4m3 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11},"
+      "{%12, %13, %14, %15},"
+      " %16,"
+      " p, %18, %19;\n"
+    "}\n"
+      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11)
+      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x48x32 TN F32+=E5M2*E4M3
+// SS variant: both operands via shared-memory descriptors.
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x48x32_F32E5M2E4M3_SS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = float[24];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      float & d00, float & d01, float & d02, float & d03,
+      float & d04, float & d05, float & d06, float & d07,
+      float & d08, float & d09, float & d10, float & d11,
+      float & d12, float & d13, float & d14, float & d15,
+      float & d16, float & d17, float & d18, float & d19,
+      float & d20, float & d21, float & d22, float & d23,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %26, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n48k32.f32.e5m2.e4m3 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23},"
+      " %24,"
+      " %25,"
+      " p, %27, %28;\n"
+    "}\n"
+      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
+        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
+        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
+        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
+        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
+        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23)
+      : "l"(desc_a),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x48x32 TN F32+=E5M2*E4M3
+// RS variant: A in registers a00..a03, B via shared-memory descriptor desc_b.
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x48x32_F32E5M2E4M3_RS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint32_t[4];
+  using BRegisters = uint64_t[1];
+  using CRegisters = float[24];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
+      uint64_t const& desc_b,
+      float & d00, float & d01, float & d02, float & d03,
+      float & d04, float & d05, float & d06, float & d07,
+      float & d08, float & d09, float & d10, float & d11,
+      float & d12, float & d13, float & d14, float & d15,
+      float & d16, float & d17, float & d18, float & d19,
+      float & d20, float & d21, float & d22, float & d23,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %29, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n48k32.f32.e5m2.e4m3 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23},"
+      "{%24, %25, %26, %27},"
+      " %28,"
+      " p, %30, %31;\n"
+    "}\n"
+      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
+        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
+        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
+        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
+        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
+        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23)
+      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x56x32 TN F16+=E5M2*E4M3
+// SS variant: both operands via shared-memory descriptors.
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x56x32_F16E5M2E4M3_SS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[14];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
+      uint32_t & d12, uint32_t & d13,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %16, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n56k32.f16.e5m2.e4m3 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13},"
+      " %14,"
+      " %15,"
+      " p, %17, %18;\n"
+    "}\n"
+      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
+        "+r"(d12), "+r"(d13)
+      : "l"(desc_a),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x56x32 TN F16+=E5M2*E4M3
+// RS variant: A in registers a00..a03, B via shared-memory descriptor desc_b.
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x56x32_F16E5M2E4M3_RS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint32_t[4];
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[14];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
+      uint64_t const& desc_b,
+      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
+      uint32_t & d12, uint32_t & d13,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %19, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n56k32.f16.e5m2.e4m3 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13},"
+      "{%14, %15, %16, %17},"
+      " %18,"
+      " p, %20, %21;\n"
+    "}\n"
+      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
+        "+r"(d12), "+r"(d13)
+      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x56x32 TN F32+=E5M2*E4M3
+// SS variant: both operands via shared-memory descriptors.
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x56x32_F32E5M2E4M3_SS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = float[28];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      float & d00, float & d01, float & d02, float & d03,
+      float & d04, float & d05, float & d06, float & d07,
+      float & d08, float & d09, float & d10, float & d11,
+      float & d12, float & d13, float & d14, float & d15,
+      float & d16, float & d17, float & d18, float & d19,
+      float & d20, float & d21, float & d22, float & d23,
+      float & d24, float & d25, float & d26, float & d27,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %30, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n56k32.f32.e5m2.e4m3 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27},"
+      " %28,"
+      " %29,"
+      " p, %31, %32;\n"
+    "}\n"
+      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
+        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
+        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
+        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
+        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
+        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
+        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27)
+      : "l"(desc_a),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x56x32 TN F32+=E5M2*E4M3
+// RS variant: A in registers a00..a03, B via shared-memory descriptor desc_b.
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x56x32_F32E5M2E4M3_RS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint32_t[4];
+  using BRegisters = uint64_t[1];
+  using CRegisters = float[28];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
+      uint64_t const& desc_b,
+      float & d00, float & d01, float & d02, float & d03,
+      float & d04, float & d05, float & d06, float & d07,
+      float & d08, float & d09, float & d10, float & d11,
+      float & d12, float & d13, float & d14, float & d15,
+      float & d16, float & d17, float & d18, float & d19,
+      float & d20, float & d21, float & d22, float & d23,
+      float & d24, float & d25, float & d26, float & d27,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %33, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n56k32.f32.e5m2.e4m3 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27},"
+      "{%28, %29, %30, %31},"
+      " %32,"
+      " p, %34, %35;\n"
+    "}\n"
+      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
+        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
+        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
+        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
+        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
+        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
+        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27)
+      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x72x32 TN F16+=E5M2*E4M3
+// SS variant: both operands via shared-memory descriptors.
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x72x32_F16E5M2E4M3_SS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[18];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
+      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
+      uint32_t & d16, uint32_t & d17,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %20, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n72k32.f16.e5m2.e4m3 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17},"
+      " %18,"
+      " %19,"
+      " p, %21, %22;\n"
+    "}\n"
+      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
+        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
+        "+r"(d16), "+r"(d17)
+      : "l"(desc_a),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x72x32 TN F16+=E5M2*E4M3
+// RS variant: A is supplied from registers (a00..a03), B from shared memory via
+// the 64-bit descriptor desc_b.  scale_D is lowered to predicate p (Zero makes
+// the wgmma overwrite D rather than accumulate); scaleA/scaleB are compile-time
+// immediates ("n" constraints) from the template parameters.
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x72x32_F16E5M2E4M3_RS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint32_t[4];
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[18];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
+      uint64_t const& desc_b,
+      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
+      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
+      uint32_t & d16, uint32_t & d17,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %23, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n72k32.f16.e5m2.e4m3 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17},"
+      "{%18, %19, %20, %21},"
+      " %22,"
+      " p, %24, %25;\n"
+    "}\n"
+      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
+        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
+        "+r"(d16), "+r"(d17)
+      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x72x32 TN F32+=E5M2*E4M3
+// SS variant: both A and B read from shared memory via descriptors desc_a/desc_b;
+// F32 accumulation uses one register per value (float[36] for a 64x72 tile).
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x72x32_F32E5M2E4M3_SS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = float[36];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      float & d00, float & d01, float & d02, float & d03,
+      float & d04, float & d05, float & d06, float & d07,
+      float & d08, float & d09, float & d10, float & d11,
+      float & d12, float & d13, float & d14, float & d15,
+      float & d16, float & d17, float & d18, float & d19,
+      float & d20, float & d21, float & d22, float & d23,
+      float & d24, float & d25, float & d26, float & d27,
+      float & d28, float & d29, float & d30, float & d31,
+      float & d32, float & d33, float & d34, float & d35,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %38, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n72k32.f32.e5m2.e4m3 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27, %28, %29, %30, %31, "
+      " %32, %33, %34, %35},"
+      " %36,"
+      " %37,"
+      " p, %39, %40;\n"
+    "}\n"
+      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
+        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
+        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
+        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
+        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
+        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
+        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
+        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
+        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35)
+      : "l"(desc_a),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x72x32 TN F32+=E5M2*E4M3
+// RS variant: A is supplied from registers (a00..a03), B from shared memory via
+// the 64-bit descriptor desc_b.  scale_D is lowered to predicate p (Zero makes
+// the wgmma overwrite D rather than accumulate); scaleA/scaleB are compile-time
+// immediates ("n" constraints) from the template parameters.
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x72x32_F32E5M2E4M3_RS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint32_t[4];
+  using BRegisters = uint64_t[1];
+  using CRegisters = float[36];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
+      uint64_t const& desc_b,
+      float & d00, float & d01, float & d02, float & d03,
+      float & d04, float & d05, float & d06, float & d07,
+      float & d08, float & d09, float & d10, float & d11,
+      float & d12, float & d13, float & d14, float & d15,
+      float & d16, float & d17, float & d18, float & d19,
+      float & d20, float & d21, float & d22, float & d23,
+      float & d24, float & d25, float & d26, float & d27,
+      float & d28, float & d29, float & d30, float & d31,
+      float & d32, float & d33, float & d34, float & d35,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %41, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n72k32.f32.e5m2.e4m3 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27, %28, %29, %30, %31, "
+      " %32, %33, %34, %35},"
+      "{%36, %37, %38, %39},"
+      " %40,"
+      " p, %42, %43;\n"
+    "}\n"
+      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
+        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
+        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
+        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
+        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
+        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
+        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
+        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
+        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35)
+      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x80x32 TN F16+=E5M2*E4M3
+// SS variant: both operands via shared-memory descriptors; F16 accumulators pack
+// two halves per 32-bit register (uint32_t[20] for a 64x80 tile).
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x80x32_F16E5M2E4M3_SS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[20];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
+      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
+      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %22, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n80k32.f16.e5m2.e4m3 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19},"
+      " %20,"
+      " %21,"
+      " p, %23, %24;\n"
+    "}\n"
+      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
+        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
+        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19)
+      : "l"(desc_a),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x80x32 TN F16+=E5M2*E4M3
+// RS variant: A in registers a00..a03, B via shared-memory descriptor desc_b.
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x80x32_F16E5M2E4M3_RS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint32_t[4];
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[20];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
+      uint64_t const& desc_b,
+      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
+      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
+      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %25, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n80k32.f16.e5m2.e4m3 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19},"
+      "{%20, %21, %22, %23},"
+      " %24,"
+      " p, %26, %27;\n"
+    "}\n"
+      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
+        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
+        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19)
+      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), 
"+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + 
"}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " 
%16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), 
"+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " 
%16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float 
& d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& 
desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& 
a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE 
static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), 
"+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, 
" + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, 
uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & 
d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float 
& d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x112x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, 
%53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & 
d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, 
uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float 
& d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), 
"+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, 
%10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & 
d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + 
CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %70, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " 
%48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " p, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, 
float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %73, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " p, %74, %75;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), 
"+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, 
%22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & 
d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), 
"+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, 
float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p, %78, %79;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), 
"+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, 
" + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, 
float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %78, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " p, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), 
"+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %81, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " p, %82, %83;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F16+=E5M2*E4M3 +template < + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), 
"+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + 
" %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + 
float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), 
"+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + 
float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p, %86, %87;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E5M2E4M3_RS_TN 
without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), 
"+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, 
desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, 
float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %86, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " p, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + 
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & 
d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %89, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " p, %90, %91;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), 
"+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, 
uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f32.e5m2.e4m3 " + 
"{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, 
%28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p, %94, %95;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F16E5M2E4M3_SS_TN +{ + using DRegisters 
= void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + " %46," + " %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), 
"+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, 
%6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & 
d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %94, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " p, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + 
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float 
& d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %97, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " p, %98, %99;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + 
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & 
d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F16+=E5M2*E4M3 +template < + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), 
"+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & 
d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %102, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " p, %103, %104;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), 
"+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float 
& d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %105, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, 
%99}," + "{%100, %101, %102, %103}," + " %104," + " p, %106, %107;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using 
CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + 
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, 
float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, 
%47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & 
d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p, %110, %111;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + 
"+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t 
& d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + 
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " p, %60, %61;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + 
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & 
d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %110, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " p, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), 
"+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & 
d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %113, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, 
%59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " p, %114, %115;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x224x32 TN F16+=E5M2*E4M3
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x224x32_F16E5M2E4M3_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[56];

  // Warpgroup MMA with f16 accumulators held in 56 uint32_t registers, updated
  // in place. Both A (e5m2) and B (e4m3) are read via SMEM matrix descriptors.
  // Asm operand map: %0-%55 = D (in/out), %56 = desc_a, %57 = desc_b,
  // %58 = scale_D (compared against 0 to form predicate p),
  // %59 = scaleA, %60 = scaleB (compile-time immediates).
  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %58, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n224k32.f16.e5m2.e4m3 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55},"
      " %56,"
      " %57,"
      " p, %59, %60;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x224x32 TN F16+=E5M2*E4M3
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x224x32_F16E5M2E4M3_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[56];

  // RS variant: A fragment is supplied in four uint32_t registers (a00..a03)
  // rather than via an SMEM descriptor; B still comes through desc_b.
  // Asm operand map: %0-%55 = D (in/out), %56-%59 = A registers, %60 = desc_b,
  // %61 = scale_D (forms predicate p), %62 = scaleA, %63 = scaleB.
  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %61, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n224k32.f16.e5m2.e4m3 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55},"
      "{%56, %57, %58, %59},"
      " %60,"
      " p, %62, %63;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55)
      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
        "l"(desc_b),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x224x32 TN F32+=E5M2*E4M3
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x224x32_F32E5M2E4M3_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[112];

  // f32-accumulator variant: 112 float registers updated in place; A and B are
  // both read via SMEM matrix descriptors.
  // Asm operand map: %0-%111 = D (in/out), %112 = desc_a, %113 = desc_b,
  // %114 = scale_D (forms predicate p), %115 = scaleA, %116 = scaleB.
  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d000, float & d001, float & d002, float & d003,
      float & d004, float & d005, float & d006, float & d007,
      float & d008, float & d009, float & d010, float & d011,
      float & d012, float & d013, float & d014, float & d015,
      float & d016, float & d017, float & d018, float & d019,
      float & d020, float & d021, float & d022, float & d023,
      float & d024, float & d025, float & d026, float & d027,
      float & d028, float & d029, float & d030, float & d031,
      float & d032, float & d033, float & d034, float & d035,
      float & d036, float & d037, float & d038, float & d039,
      float & d040, float & d041, float & d042, float & d043,
      float & d044, float & d045, float & d046, float & d047,
      float & d048, float & d049, float & d050, float & d051,
      float & d052, float & d053, float & d054, float & d055,
      float & d056, float & d057, float & d058, float & d059,
      float & d060, float & d061, float & d062, float & d063,
      float & d064, float & d065, float & d066, float & d067,
      float & d068, float & d069, float & d070, float & d071,
      float & d072, float & d073, float & d074, float & d075,
      float & d076, float & d077, float & d078, float & d079,
      float & d080, float & d081, float & d082, float & d083,
      float & d084, float & d085, float & d086, float & d087,
      float & d088, float & d089, float & d090, float & d091,
      float & d092, float & d093, float & d094, float & d095,
      float & d096, float & d097, float & d098, float & d099,
      float & d100, float & d101, float & d102, float & d103,
      float & d104, float & d105, float & d106, float & d107,
      float & d108, float & d109, float & d110, float & d111,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %114, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n224k32.f32.e5m2.e4m3 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103, "
      " %104, %105, %106, %107, %108, %109, %110, %111},"
      " %112,"
      " %113,"
      " p, %115, %116;\n"
    "}\n"
      : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
        "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
        "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
        "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
        "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
        "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
        "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
        "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
        "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
        "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
        "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
        "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
        "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
        "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
        "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
        "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
        "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
        "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
        "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
        "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
        "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
        "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
        "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
        "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
        "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
        "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
        "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107),
        "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x224x32 TN F32+=E5M2*E4M3
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x224x32_F32E5M2E4M3_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using BRegisters = uint64_t[1];
  using CRegisters = float[112];

  // RS + f32 variant: A fragment in registers (a000..a003), B via desc_b.
  // Asm operand map: %0-%111 = D (in/out), %112-%115 = A registers,
  // %116 = desc_b, %117 = scale_D (forms predicate p), %118 = scaleA, %119 = scaleB.
  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
      uint64_t const& desc_b,
      float & d000, float & d001, float & d002, float & d003,
      float & d004, float & d005, float & d006, float & d007,
      float & d008, float & d009, float & d010, float & d011,
      float & d012, float & d013, float & d014, float & d015,
      float & d016, float & d017, float & d018, float & d019,
      float & d020, float & d021, float & d022, float & d023,
      float & d024, float & d025, float & d026, float & d027,
      float & d028, float & d029, float & d030, float & d031,
      float & d032, float & d033, float & d034, float & d035,
      float & d036, float & d037, float & d038, float & d039,
      float & d040, float & d041, float & d042, float & d043,
      float & d044, float & d045, float & d046, float & d047,
      float & d048, float & d049, float & d050, float & d051,
      float & d052, float & d053, float & d054, float & d055,
      float & d056, float & d057, float & d058, float & d059,
      float & d060, float & d061, float & d062, float & d063,
      float & d064, float & d065, float & d066, float & d067,
      float & d068, float & d069, float & d070, float & d071,
      float & d072, float & d073, float & d074, float & d075,
      float & d076, float & d077, float & d078, float & d079,
      float & d080, float & d081, float & d082, float & d083,
      float & d084, float & d085, float & d086, float & d087,
      float & d088, float & d089, float & d090, float & d091,
      float & d092, float & d093, float & d094, float & d095,
      float & d096, float & d097, float & d098, float & d099,
      float & d100, float & d101, float & d102, float & d103,
      float & d104, float & d105, float & d106, float & d107,
      float & d108, float & d109, float & d110, float & d111,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %117, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n224k32.f32.e5m2.e4m3 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103, "
      " %104, %105, %106, %107, %108, %109, %110, %111},"
      "{%112, %113, %114, %115},"
      " %116,"
      " p, %118, %119;\n"
    "}\n"
      : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
        "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
        "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
        "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
        "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
        "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
        "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
        "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
        "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
        "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
        "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
        "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
        "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
        "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
        "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
        "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
        "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
        "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
        "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
        "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
        "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
        "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
        "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
        "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
        "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
        "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
        "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107),
        "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111)
      : "r"(a000), "r"(a001), "r"(a002), "r"(a003),
        "l"(desc_b),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x232x32 TN F16+=E5M2*E4M3
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x232x32_F16E5M2E4M3_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[58];

  // N=232 shape: f16 accumulators in 58 uint32_t registers; A/B via descriptors.
  // Asm operand map: %0-%57 = D (in/out), %58 = desc_a, %59 = desc_b,
  // %60 = scale_D (forms predicate p), %61 = scaleA, %62 = scaleB.
  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
      uint32_t & d56, uint32_t & d57,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %60, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n232k32.f16.e5m2.e4m3 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57},"
      " %58,"
      " %59,"
      " p, %61, %62;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
        "+r"(d56), "+r"(d57)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x232x32 TN F16+=E5M2*E4M3
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
struct MMA_64x232x32_F16E5M2E4M3_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[58];

  // RS variant for N=232: A fragment in registers (a00..a03), B via desc_b.
  // Asm operand map: %0-%57 = D (in/out), %58-%61 = A registers, %62 = desc_b,
  // %63 = scale_D (forms predicate p), %64 = scaleA, %65 = scaleB.
  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
      uint32_t & d56, uint32_t & d57,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %63, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n232k32.f16.e5m2.e4m3 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57},"
      "{%58, %59, %60, %61},"
      " %62,"
      " p, %64, %65;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
        "+r"(d56), "+r"(d57)
      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
        "l"(desc_b),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, 
float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %118, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " p, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), 
"+f"(d062), "+f"(d063),
+        "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
+        "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
+        "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
+        "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
+        "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
+        "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
+        "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
+        "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
+        "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
+        "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
+        "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107),
+        "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111),
+        "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115)
+      : "l"(desc_a),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x232x32 TN F32+=E5M2*E4M3
+// RS variant: the A fragment is supplied in registers (4 x u32 of packed e5m2 bytes),
+// B is read from shared memory through a 64-bit matrix descriptor; 116 f32 accumulators.
+// scaleA/scaleB are compile-time immediates ("n" constraints) baked into the instruction.
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x232x32_F32E5M2E4M3_RS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint32_t[4];
+  using BRegisters = uint64_t[1];
+  using CRegisters = float[116];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
+      uint64_t const& desc_b,
+      float & d000, float & d001, float & d002, float & d003,
+      float & d004, float & d005, float & d006, float & d007,
+      float & d008, float & d009, float & d010, float & d011,
+      float & d012, float & d013, float & d014, float & d015,
+      float & d016, float & d017, float & d018, float & d019,
+      float & d020, float & d021, float & d022, float & d023,
+      float & d024, float & d025, float & d026, float & d027,
+      float & d028, float & d029, float & d030, float & d031,
+      float & d032, float & d033, float & d034, float & d035,
+      float & d036, float & d037, float & d038, float & d039,
+      float & d040, float & d041, float & d042, float & d043,
+      float & d044, float & d045, float & d046, float & d047,
+      float & d048, float & d049, float & d050, float & d051,
+      float & d052, float & d053, float & d054, float & d055,
+      float & d056, float & d057, float & d058, float & d059,
+      float & d060, float & d061, float & d062, float & d063,
+      float & d064, float & d065, float & d066, float & d067,
+      float & d068, float & d069, float & d070, float & d071,
+      float & d072, float & d073, float & d074, float & d075,
+      float & d076, float & d077, float & d078, float & d079,
+      float & d080, float & d081, float & d082, float & d083,
+      float & d084, float & d085, float & d086, float & d087,
+      float & d088, float & d089, float & d090, float & d091,
+      float & d092, float & d093, float & d094, float & d095,
+      float & d096, float & d097, float & d098, float & d099,
+      float & d100, float & d101, float & d102, float & d103,
+      float & d104, float & d105, float & d106, float & d107,
+      float & d108, float & d109, float & d110, float & d111,
+      float & d112, float & d113, float & d114, float & d115,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
+    // Predicate p = (scale_D != 0): when false the wgmma ignores the existing accumulator
+    // contents (D = A*B instead of D += A*B, per PTX wgmma scale-d semantics).
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %121, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n232k32.f32.e5m2.e4m3 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27, %28, %29, %30, %31, "
+      " %32, %33, %34, %35, %36, %37, %38, %39, "
+      " %40, %41, %42, %43, %44, %45, %46, %47, "
+      " %48, %49, %50, %51, %52, %53, %54, %55, "
+      " %56, %57, %58, %59, %60, %61, %62, %63, "
+      " %64, %65, %66, %67, %68, %69, %70, %71, "
+      " %72, %73, %74, %75, %76, %77, %78, %79, "
+      " %80, %81, %82, %83, %84, %85, %86, %87, "
+      " %88, %89, %90, %91, %92, %93, %94, %95, "
+      " %96, %97, %98, %99, %100, %101, %102, %103, "
+      " %104, %105, %106, %107, %108, %109, %110, %111, "
+      " %112, %113, %114, %115},"
+      "{%116, %117, %118, %119},"
+      " %120,"
+      " p, %122, %123;\n"
+    "}\n"
+      : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
+        "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
+        "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
+        "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
+        "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
+        "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
+        "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
+        "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
+        "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
+        "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
+        "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
+        "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
+        "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
+        "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
+        "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
+        "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
+        "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
+        "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
+        "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
+        "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
+        "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
+        "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
+        "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
+        "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
+        "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
+        "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
+        "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107),
+        "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111),
+        "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115)
+      : "r"(a000), "r"(a001), "r"(a002), "r"(a003),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x240x32 TN F16+=E5M2*E4M3
+// SS variant: A and B both read from shared memory via 64-bit matrix descriptors.
+// Accumulator: 60 x u32 (f16 results packed two per register, 64x240 tile).
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x240x32_F16E5M2E4M3_SS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[60];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
+      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
+      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
+      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
+      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
+      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
+      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
+      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
+      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
+      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
+      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
+      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
+      uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    // p = (scale_D != 0): when false the wgmma writes A*B without accumulating into C.
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %62, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n240k32.f16.e5m2.e4m3 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27, %28, %29, %30, %31, "
+      " %32, %33, %34, %35, %36, %37, %38, %39, "
+      " %40, %41, %42, %43, %44, %45, %46, %47, "
+      " %48, %49, %50, %51, %52, %53, %54, %55, "
+      " %56, %57, %58, %59},"
+      " %60,"
+      " %61,"
+      " p, %63, %64;\n"
+    "}\n"
+      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
+        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
+        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
+        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
+        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
+        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
+        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
+        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
+        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
+        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
+        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
+        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
+        "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59)
+      : "l"(desc_a),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x240x32 TN F16+=E5M2*E4M3
+// RS variant: A fragment in registers (4 x u32), B via shared-memory descriptor.
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x240x32_F16E5M2E4M3_RS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint32_t[4];
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[60];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
+      uint64_t const& desc_b,
+      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
+      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
+      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
+      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
+      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
+      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
+      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
+      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
+      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
+      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
+      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
+      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
+      uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
+    // p = (scale_D != 0): when false the wgmma writes A*B without accumulating into C.
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %65, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n240k32.f16.e5m2.e4m3 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27, %28, %29, %30, %31, "
+      " %32, %33, %34, %35, %36, %37, %38, %39, "
+      " %40, %41, %42, %43, %44, %45, %46, %47, "
+      " %48, %49, %50, %51, %52, %53, %54, %55, "
+      " %56, %57, %58, %59},"
+      "{%60, %61, %62, %63},"
+      " %64,"
+      " p, %66, %67;\n"
+    "}\n"
+      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
+        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
+        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
+        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
+        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
+        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
+        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
+        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
+        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
+        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
+        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
+        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
+        "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59)
+      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x240x32 TN F32+=E5M2*E4M3
+// SS variant: A and B via shared-memory descriptors; 120 f32 accumulators.
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x240x32_F32E5M2E4M3_SS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = float[120];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      float & d000, float & d001, float & d002, float & d003,
+      float & d004, float & d005, float & d006, float & d007,
+      float & d008, float & d009, float & d010, float & d011,
+      float & d012, float & d013, float & d014, float & d015,
+      float & d016, float & d017, float & d018, float & d019,
+      float & d020, float & d021, float & d022, float & d023,
+      float & d024, float & d025, float & d026, float & d027,
+      float & d028, float & d029, float & d030, float & d031,
+      float & d032, float & d033, float & d034, float & d035,
+      float & d036, float & d037, float & d038, float & d039,
+      float & d040, float & d041, float & d042, float & d043,
+      float & d044, float & d045, float & d046, float & d047,
+      float & d048, float & d049, float & d050, float & d051,
+      float & d052, float & d053, float & d054, float & d055,
+      float & d056, float & d057, float & d058, float & d059,
+      float & d060, float & d061, float & d062, float & d063,
+      float & d064, float & d065, float & d066, float & d067,
+      float & d068, float & d069, float & d070, float & d071,
+      float & d072, float & d073, float & d074, float & d075,
+      float & d076, float & d077, float & d078, float & d079,
+      float & d080, float & d081, float & d082, float & d083,
+      float & d084, float & d085, float & d086, float & d087,
+      float & d088, float & d089, float & d090, float & d091,
+      float & d092, float & d093, float & d094, float & d095,
+      float & d096, float & d097, float & d098, float & d099,
+      float & d100, float & d101, float & d102, float & d103,
+      float & d104, float & d105, float & d106, float & d107,
+      float & d108, float & d109, float & d110, float & d111,
+      float & d112, float & d113, float & d114, float & d115,
+      float & d116, float & d117, float & d118, float & d119,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    // p = (scale_D != 0): when false the wgmma writes A*B without accumulating into C.
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %122, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n240k32.f32.e5m2.e4m3 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27, %28, %29, %30, %31, "
+      " %32, %33, %34, %35, %36, %37, %38, %39, "
+      " %40, %41, %42, %43, %44, %45, %46, %47, "
+      " %48, %49, %50, %51, %52, %53, %54, %55, "
+      " %56, %57, %58, %59, %60, %61, %62, %63, "
+      " %64, %65, %66, %67, %68, %69, %70, %71, "
+      " %72, %73, %74, %75, %76, %77, %78, %79, "
+      " %80, %81, %82, %83, %84, %85, %86, %87, "
+      " %88, %89, %90, %91, %92, %93, %94, %95, "
+      " %96, %97, %98, %99, %100, %101, %102, %103, "
+      " %104, %105, %106, %107, %108, %109, %110, %111, "
+      " %112, %113, %114, %115, %116, %117, %118, %119},"
+      " %120,"
+      " %121,"
+      " p, %123, %124;\n"
+    "}\n"
+      : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
+        "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
+        "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
+        "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
+        "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
+        "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
+        "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
+        "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
+        "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
+        "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
+        "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
+        "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
+        "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
+        "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
+        "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
+        "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
+        "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
+        "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
+        "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
+        "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
+        "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
+        "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
+        "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
+        "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
+        "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
+        "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
+        "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107),
+        "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111),
+        "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115),
+        "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119)
+      : "l"(desc_a),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x240x32 TN F32+=E5M2*E4M3
+// RS variant: A fragment in registers (4 x u32), B via shared-memory descriptor; 120 f32 accumulators.
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x240x32_F32E5M2E4M3_RS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint32_t[4];
+  using BRegisters = uint64_t[1];
+  using CRegisters = float[120];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
+      uint64_t const& desc_b,
+      float & d000, float & d001, float & d002, float & d003,
+      float & d004, float & d005, float & d006, float & d007,
+      float & d008, float & d009, float & d010, float & d011,
+      float & d012, float & d013, float & d014, float & d015,
+      float & d016, float & d017, float & d018, float & d019,
+      float & d020, float & d021, float & d022, float & d023,
+      float & d024, float & d025, float & d026, float & d027,
+      float & d028, float & d029, float & d030, float & d031,
+      float & d032, float & d033, float & d034, float & d035,
+      float & d036, float & d037, float & d038, float & d039,
+      float & d040, float & d041, float & d042, float & d043,
+      float & d044, float & d045, float & d046, float & d047,
+      float & d048, float & d049, float & d050, float & d051,
+      float & d052, float & d053, float & d054, float & d055,
+      float & d056, float & d057, float & d058, float & d059,
+      float & d060, float & d061, float & d062, float & d063,
+      float & d064, float & d065, float & d066, float & d067,
+      float & d068, float & d069, float & d070, float & d071,
+      float & d072, float & d073, float & d074, float & d075,
+      float & d076, float & d077, float & d078, float & d079,
+      float & d080, float & d081, float & d082, float & d083,
+      float & d084, float & d085, float & d086, float & d087,
+      float & d088, float & d089, float & d090, float & d091,
+      float & d092, float & d093, float & d094, float & d095,
+      float & d096, float & d097, float & d098, float & d099,
+      float & d100, float & d101, float & d102, float & d103,
+      float & d104, float & d105, float & d106, float & d107,
+      float & d108, float & d109, float & d110, float & d111,
+      float & d112, float & d113, float & d114, float & d115,
+      float & d116, float & d117, float & d118, float & d119,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
+    // p = (scale_D != 0): when false the wgmma writes A*B without accumulating into C.
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %125, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n240k32.f32.e5m2.e4m3 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27, %28, %29, %30, %31, "
+      " %32, %33, %34, %35, %36, %37, %38, %39, "
+      " %40, %41, %42, %43, %44, %45, %46, %47, "
+      " %48, %49, %50, %51, %52, %53, %54, %55, "
+      " %56, %57, %58, %59, %60, %61, %62, %63, "
+      " %64, %65, %66, %67, %68, %69, %70, %71, "
+      " %72, %73, %74, %75, %76, %77, %78, %79, "
+      " %80, %81, %82, %83, %84, %85, %86, %87, "
+      " %88, %89, %90, %91, %92, %93, %94, %95, "
+      " %96, %97, %98, %99, %100, %101, %102, %103, "
+      " %104, %105, %106, %107, %108, %109, %110, %111, "
+      " %112, %113, %114, %115, %116, %117, %118, %119},"
+      "{%120, %121, %122, %123},"
+      " %124,"
+      " p, %126, %127;\n"
+    "}\n"
+      : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
+        "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
+        "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
+        "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
+        "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
+        "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
+        "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
+        "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
+        "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
+        "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
+        "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
+        "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
+        "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
+        "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
+        "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
+        "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
+        "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
+        "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
+        "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
+        "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
+        "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
+        "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
+        "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
+        "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
+        "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
+        "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
+        "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107),
+        "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111),
+        "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115),
+        "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119)
+      : "r"(a000), "r"(a001), "r"(a002), "r"(a003),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x248x32 TN F16+=E5M2*E4M3
+// SS variant: A and B via shared-memory descriptors; 62 x u32 accumulators (packed f16 pairs).
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x248x32_F16E5M2E4M3_SS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[62];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
+      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
+      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
+      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
+      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
+      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
+      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
+      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
+      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
+      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
+      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
+      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
+      uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
+      uint32_t & d60, uint32_t & d61,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    // p = (scale_D != 0): when false the wgmma writes A*B without accumulating into C.
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %64, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n248k32.f16.e5m2.e4m3 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27, %28, %29, %30, %31, "
+      " %32, %33, %34, %35, %36, %37, %38, %39, "
+      " %40, %41, %42, %43, %44, %45, %46, %47, "
+      " %48, %49, %50, %51, %52, %53, %54, %55, "
+      " %56, %57, %58, %59, %60, %61},"
+      " %62,"
+      " %63,"
+      " p, %65, %66;\n"
+    "}\n"
+      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
+        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
+        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
+        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
+        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
+        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
+        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
+        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
+        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
+        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
+        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
+        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
+        "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
+        "+r"(d60), "+r"(d61)
+      : "l"(desc_a),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x248x32 TN F16+=E5M2*E4M3
+// RS variant: A fragment in registers (4 x u32), B via shared-memory descriptor.
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x248x32_F16E5M2E4M3_RS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint32_t[4];
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[62];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
+      uint64_t const& desc_b,
+      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
+      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
+      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
+      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
+      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
+      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
+      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
+      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
+      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
+      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
+      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
+      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
+      uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
+      uint32_t & d60, uint32_t & d61,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
+    // p = (scale_D != 0): when false the wgmma writes A*B without accumulating into C.
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %67, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n248k32.f16.e5m2.e4m3 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27, %28, %29, %30, %31, "
+      " %32, %33, %34, %35, %36, %37, %38, %39, "
+      " %40, %41, %42, %43, %44, %45, %46, %47, "
+      " %48, %49, %50, %51, %52, %53, %54, %55, "
+      " %56, %57, %58, %59, %60, %61},"
+      "{%62, %63, %64, %65},"
+      " %66,"
+      " p, %68, %69;\n"
+    "}\n"
+      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
+        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
+        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
+        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
+        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
+        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
+        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
+        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
+        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
+        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
+        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
+        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
+        "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
+        "+r"(d60), "+r"(d61)
+      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x248x32 TN F32+=E5M2*E4M3
+// SS variant: A and B via shared-memory descriptors; 124 f32 accumulators.
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x248x32_F32E5M2E4M3_SS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = float[124];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      float & d000, float & d001, float & d002, float & d003,
+      float & d004, float & d005, float & d006, float & d007,
+      float & d008, float & d009, float & d010, float & d011,
+      float & d012, float & d013, float & d014, float & d015,
+      float & d016, float & d017, float & d018, float & d019,
+      float & d020, float & d021, float & d022, float & d023,
+      float & d024, float & d025, float & d026, float & d027,
+      float & d028, float & d029, float & d030, float & d031,
+      float & d032, float & d033, float & d034, float & d035,
+      float & d036, float & d037, float & d038, float & d039,
+      float & d040, float & d041, float & d042, float & d043,
+      float & d044, float & d045, float & d046, float & d047,
+      float & d048, float & d049, float & d050, float & d051,
+      float & d052, float & d053, float & d054, float & d055,
+      float & d056, float & d057, float & d058, float & d059,
+      float & d060, float & d061, float & d062, float & d063,
+      float & d064, float & d065, float & d066, float & d067,
+      float & d068, float & d069, float & d070, float & d071,
+      float & d072, float & d073, float & d074, float & d075,
+      float & d076, float & d077, float & d078, float & d079,
+      float & d080, float & d081, float & d082, float & d083,
+      float & d084, float & d085, float & d086, float & d087,
+      float & d088, float & d089, float & d090, float & d091,
+      float & d092, float & d093, float & d094, float & d095,
+      float & d096, float & d097, float & d098, float & d099,
+      float & d100, float & d101, float & d102, float & d103,
+      float & d104, float & d105, float & d106, float & d107,
+      float & d108, float & d109, float & d110, float & d111,
+      float & d112, float & d113, float & d114, float & d115,
+      float & d116, float & d117, float & d118, float & d119,
+      float & d120, float & d121, float & d122, float & d123,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    // p = (scale_D != 0): when false the wgmma writes A*B without accumulating into C.
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %126, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n248k32.f32.e5m2.e4m3 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27, %28, %29, %30, %31, "
+      " %32, %33, %34, %35, %36, %37, %38, %39, "
+      " %40, %41, %42, %43, %44, %45, %46, %47, "
+      " %48, %49, %50, %51, %52, %53, %54, %55, "
+      " %56, %57, %58, %59, %60, %61, %62, %63, "
+      " %64, %65, %66, %67, %68, %69, %70, %71, "
+      " %72, %73, %74, %75, %76, %77, %78, %79, "
+      " %80, %81, %82, %83, %84, %85, %86, %87, "
+      " %88, %89, %90, %91, %92, %93, %94, %95, "
+      " %96, %97, %98, %99, %100, %101, %102, %103, "
+      " %104, %105, %106, %107, %108, %109, %110, %111, "
+      " %112, %113, %114, %115, %116, %117, %118, %119, "
+      " %120, %121, %122, %123},"
+      " %124,"
+      " %125,"
+      " p, %127, %128;\n"
+    "}\n"
+      : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
+        "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
+        "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
+        "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
+        "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
+        "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
+        "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
+        "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
+        "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
+        "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
+        "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
+        "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
+        "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
+        "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
+        "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
+        "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
+        "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
+        "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
+        "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
+        "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
+        "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
+        "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
+        "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
+        "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
+        "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
+        "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
+        "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107),
+        "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111),
+        "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115),
+        "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119),
+        "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123)
+      : "l"(desc_a),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x248x32 TN F32+=E5M2*E4M3
+// RS variant: A fragment in registers (4 x u32), B via shared-memory descriptor; 124 f32 accumulators.
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x248x32_F32E5M2E4M3_RS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint32_t[4];
+  using BRegisters = uint64_t[1];
+  using CRegisters = float[124];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
+      uint64_t const& desc_b,
+      float & d000, float & d001, float & d002, float & d003,
+      float & d004, float & d005, float & d006, float & d007,
+      float & d008, float & d009, float & d010, float & d011,
+      float & d012, float & d013, float & d014, float & d015,
+      float & d016, float & d017, float & d018, float & d019,
+      float & d020, float & d021, float & d022, float & d023,
+      float & d024, float & d025, float & d026, float & d027,
+      float & d028, float & d029, float & d030, float & d031,
+      float & d032, float & d033, float & d034, float & d035,
+      float & d036, float & d037, float & d038, float & d039,
+      float & d040, float & d041, float & d042, float & d043,
+      float & d044, float & d045, float & d046, float & d047,
+      float & d048, float & d049, float & d050, float & d051,
+      float & d052, float & d053, float & d054, float & d055,
+      float & d056, float & d057, float & d058, float & d059,
+      float & d060, float & d061, float & d062, float & d063,
+      float & d064, float & d065, float & d066, float & d067,
+      float & d068, float & d069, float & d070, float & d071,
+      float & d072, float & d073, float & d074, float & d075,
+      float & d076, float & d077, float & d078, float & d079,
+      float & d080, float & d081, float & d082, float & d083,
+      float & d084, float & d085, float & d086, float & d087,
+      float & d088, float & d089, float & d090, float & d091,
+      float & d092, float & d093, float & d094, float & d095,
+      float & d096, float & d097, float & d098, float & d099,
+      float & d100, float & d101, float & d102, float & d103,
+      float & d104, float & d105, float & d106, float & d107,
+      float & d108, float & d109, float & d110, float & d111,
+      float & d112, float & d113, float & d114, float & d115,
+      float & d116, float & d117, float & d118, float & d119,
+      float & d120, float & d121, float & d122, float & d123,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
+    // p = (scale_D != 0): when false the wgmma writes A*B without accumulating into C.
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %129, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n248k32.f32.e5m2.e4m3 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27, %28, %29, %30, %31, "
+      " %32, %33, %34, %35, %36, %37, %38, %39, "
+      " %40, %41, %42, %43, %44, %45, %46, %47, "
+      " %48, %49, %50, %51, %52, %53, %54, %55, "
+      " %56, %57, %58, %59, %60, %61, %62, %63, "
+      " %64, %65, %66, %67, %68, %69, %70, %71, "
+      " %72, %73, %74, %75, %76, %77, %78, %79, "
+      " %80, %81, %82, %83, %84, %85, %86, %87, "
+      " %88, %89, %90, %91, %92, %93, %94, %95, "
+      " %96, %97, %98, %99, %100, %101, %102, %103, "
+      " %104, %105, %106, %107, %108, %109, %110, %111, "
+      " %112, %113, %114, %115, %116, %117, %118, %119, "
+      " %120, %121, %122, %123},"
+      "{%124, %125, %126, %127},"
+      " %128,"
+      " p, %130, %131;\n"
+    "}\n"
+      : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
+        "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
+        "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
+        "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
+        "+f"(d016), "+f"(d017),
"+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters 
= uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : 
"r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct 
MMA_64x24x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const 
scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), 
"+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F32+=E5M2*E5M2  (RS: A operand in 4x 32-bit registers, B via smem descriptor; fp32 accumulation) +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" /* p <- (scale_D != 0); passed as the scale-D predicate operand of wgmma */ + "wgmma.mma_async.sync.aligned.m64n40k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F16+=E5M2*E5M2  (SS: both operands read via smem descriptors desc_a/desc_b) +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F16+=E5M2*E5M2  (RS: A operand in registers, B via smem descriptor) +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + 
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F32+=E5M2*E5M2  (SS: both operands read via smem descriptors desc_a/desc_b) +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n48k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F32+=E5M2*E5M2  (RS: A operand in registers, B via smem descriptor) +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, 
%13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F16+=E5M2*E5M2  (SS: both operands read via smem descriptors desc_a/desc_b) +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) 
+ : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F16+=E5M2*E5M2  (RS: A operand in registers, B via smem descriptor) +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F32+=E5M2*E5M2  (SS: both operands read via smem descriptors desc_a/desc_b; fp32 accumulation of e5m2*e5m2) +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; /* 28 fp32 accumulator registers per thread */ + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" /* p <- (scale_D != 0); passed as the scale-D predicate operand of wgmma */ + "wgmma.mma_async.sync.aligned.m64n56k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F32E5M2E5M2_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> 
+struct MMA_64x72x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static 
void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One 
+> +struct MMA_64x72x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F32E5M2E5M2_RS_TN 
without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn 
scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F32E5M2E5M2_SS_TN +{ + using 
DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E5M2E5M2_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), 
"+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + 
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), 
"r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + 
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & 
d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, 
uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float 
& d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F32E5M2E5M2_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+f"(d00), "+f"(d01), 
"+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm 
volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, 
float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN 
F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), 
"+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg 
.pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float 
& d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F32E5M2E5M2_SS_TN 
without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " 
%48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + 
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, 
uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F32E5M2E5M2_SS_TN +{ + using DRegisters 
= void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %70, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " p, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + 
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & 
d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %73, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " p, %74, %75;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), 
"+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, 
" + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, 
float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + 
"l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" 
+ "wgmma.mma_async.sync.aligned.m64n144k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p, %78, %79;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = 
uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); 
+#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), 
"+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & 
d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %78, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " p, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x152x32 TN F32+=E5M2*E5M2
// Warpgroup-wide MMA: D(64x152) = A(64x32) * B(32x152) + (scale_D ? D : 0),
// with e5m2 (fp8) operands and f32 accumulation.
// RS variant: A comes from registers (a00..a03), B via a 64-bit matrix
// descriptor (desc_b) — presumably a shared-memory descriptor, consistent with
// synclog_emit_wgmma_reg_smem; confirm against the GMMA descriptor docs.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,   // compile-time A scale ("n" asm constraint)
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One    // compile-time B scale ("n" asm constraint)
>
struct MMA_64x152x32_F32E5M2E5M2_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];   // 4 x 32-bit registers of packed e5m2 A fragments
  using BRegisters = uint64_t[1];   // B matrix descriptor
  using CRegisters = float[76];     // 152/2 = 76 f32 accumulators per thread

  // Issues one wgmma.mma_async; accumulators d00..d75 are read-modify-write
  // ("+f" constraints). scale_D selects whether the existing accumulator
  // contents are added (predicate p = (scale_D != 0)).
  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      float         & d00, float         & d01, float         & d02, float         & d03,
      float         & d04, float         & d05, float         & d06, float         & d07,
      float         & d08, float         & d09, float         & d10, float         & d11,
      float         & d12, float         & d13, float         & d14, float         & d15,
      float         & d16, float         & d17, float         & d18, float         & d19,
      float         & d20, float         & d21, float         & d22, float         & d23,
      float         & d24, float         & d25, float         & d26, float         & d27,
      float         & d28, float         & d29, float         & d30, float         & d31,
      float         & d32, float         & d33, float         & d34, float         & d35,
      float         & d36, float         & d37, float         & d38, float         & d39,
      float         & d40, float         & d41, float         & d42, float         & d43,
      float         & d44, float         & d45, float         & d46, float         & d47,
      float         & d48, float         & d49, float         & d50, float         & d51,
      float         & d52, float         & d53, float         & d54, float         & d55,
      float         & d56, float         & d57, float         & d58, float         & d59,
      float         & d60, float         & d61, float         & d62, float         & d63,
      float         & d64, float         & d65, float         & d66, float         & d67,
      float         & d68, float         & d69, float         & d70, float         & d71,
      float         & d72, float         & d73, float         & d74, float         & d75,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    // Operand map: %0-%75 accumulators, %76-%79 A registers, %80 desc_b,
    // %81 scale_D (feeds predicate p), %82/%83 immediate scaleA/scaleB.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %81, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n152k32.f32.e5m2.e5m2 "
      "{%0,   %1,   %2,   %3,   %4,   %5,   %6,   %7,   "
      " %8,   %9,   %10,  %11,  %12,  %13,  %14,  %15,  "
      " %16,  %17,  %18,  %19,  %20,  %21,  %22,  %23,  "
      " %24,  %25,  %26,  %27,  %28,  %29,  %30,  %31,  "
      " %32,  %33,  %34,  %35,  %36,  %37,  %38,  %39,  "
      " %40,  %41,  %42,  %43,  %44,  %45,  %46,  %47,  "
      " %48,  %49,  %50,  %51,  %52,  %53,  %54,  %55,  "
      " %56,  %57,  %58,  %59,  %60,  %61,  %62,  %63,  "
      " %64,  %65,  %66,  %67,  %68,  %69,  %70,  %71,  "
      " %72,  %73,  %74,  %75},"
      "{%76,  %77,  %78,  %79},"
      " %80,"
      " p,    %82,  %83;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
        "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
        "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
        "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
        "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
        "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
        "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75)
      :  "r"(a00),  "r"(a01),  "r"(a02),  "r"(a03),
         "l"(desc_b),
         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x160x32 TN F16+=E5M2*E5M2
// Warpgroup-wide MMA: D(64x160) = A(64x32) * B(32x160) + (scale_D ? D : 0),
// with e5m2 (fp8) operands and f16 accumulation (two halves packed per uint32_t).
// SS variant: both A and B are referenced through 64-bit matrix descriptors
// (desc_a, desc_b) — presumably shared-memory descriptors, consistent with
// synclog_emit_wgmma_smem_smem; confirm against the GMMA descriptor docs.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,   // compile-time A scale ("n" asm constraint)
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One    // compile-time B scale ("n" asm constraint)
>
struct MMA_64x160x32_F16E5M2E5M2_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];   // A matrix descriptor
  using BRegisters = uint64_t[1];   // B matrix descriptor
  using CRegisters = uint32_t[40];  // 160/4 = 40 registers, each holding 2 packed f16 accumulators

  // Issues one wgmma.mma_async; accumulators d00..d39 are read-modify-write
  // ("+r" constraints). scale_D selects whether the existing accumulator
  // contents are added (predicate p = (scale_D != 0)).
  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t      & d00, uint32_t      & d01, uint32_t      & d02, uint32_t      & d03,
      uint32_t      & d04, uint32_t      & d05, uint32_t      & d06, uint32_t      & d07,
      uint32_t      & d08, uint32_t      & d09, uint32_t      & d10, uint32_t      & d11,
      uint32_t      & d12, uint32_t      & d13, uint32_t      & d14, uint32_t      & d15,
      uint32_t      & d16, uint32_t      & d17, uint32_t      & d18, uint32_t      & d19,
      uint32_t      & d20, uint32_t      & d21, uint32_t      & d22, uint32_t      & d23,
      uint32_t      & d24, uint32_t      & d25, uint32_t      & d26, uint32_t      & d27,
      uint32_t      & d28, uint32_t      & d29, uint32_t      & d30, uint32_t      & d31,
      uint32_t      & d32, uint32_t      & d33, uint32_t      & d34, uint32_t      & d35,
      uint32_t      & d36, uint32_t      & d37, uint32_t      & d38, uint32_t      & d39,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    // Operand map: %0-%39 accumulators, %40 desc_a, %41 desc_b,
    // %42 scale_D (feeds predicate p), %43/%44 immediate scaleA/scaleB.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %42, 0;\n"
      "wgmma.mma_async.sync.aligned.m64n160k32.f16.e5m2.e5m2 "
      "{%0,   %1,   %2,   %3,   %4,   %5,   %6,   %7,   "
      " %8,   %9,   %10,  %11,  %12,  %13,  %14,  %15,  "
      " %16,  %17,  %18,  %19,  %20,  %21,  %22,  %23,  "
      " %24,  %25,  %26,  %27,  %28,  %29,  %30,  %31,  "
      " %32,  %33,  %34,  %35,  %36,  %37,  %38,  %39},"
      " %40,"
      " %41,"
      " p,    %43,  %44;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39)
      :  "l"(desc_a),
         "l"(desc_b),
         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + 
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, 
float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + 
"setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p, %86, %87;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), 
"+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, 
%22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, 
float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %86, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " p, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), 
"+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + 
float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %89, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " p, %90, %91;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + 
"+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, 
%11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & 
d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using 
ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, 
%58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, 
uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, 
%83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p, %94, %95;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, 
uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + " %46," + " %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F16E5M2E5M2_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " p, %52, %53;\n" + "}\n" + : 
"+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & 
d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %94, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " p, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), 
"+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float 
& d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %97, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " p, %98, %99;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), 
"+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & 
d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + 
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + 
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, 
float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %102, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " p, %103, %104;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), 
"+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, 
float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %105, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " p, %106, %107;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), 
"+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, 
uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, 
%21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, 
float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, 
%99, %100, %101, %102, %103}," + " %104," + " %105," + " p, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = 
uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n208k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p, %110, %111;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), 
"+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n216k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, 
uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " p, %60, %61;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), 
"+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & 
d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %110, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " p, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), 
"+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & 
d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %113, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + 
" %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " p, %114, %115;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct 
MMA_64x224x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), 
"+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, 
uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), 
"+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, 
+ float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p, %118, %119;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), 
"+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + 
uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), 
"+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & 
d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct 
MMA_64x232x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %118, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " p, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), 
"+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float 
& d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %121, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " p, %122, %123;\n" + "}\n" + : "+f"(d000), "+f"(d001), 
"+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F16E5M2E5M2_SS_TN +{ + using 
DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), 
"+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + 
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E5M2E5M2_RS_TN 
without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, 
float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), 
"+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float 
& d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, 
%49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p, %126, %127;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + 
: "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + 
"{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t 
const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), 
"+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, 
float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %126, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, 
%108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " p, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & 
d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %129, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " p, %130, %131;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), 
"+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace SM90::GMMA + +} // namespace cute diff --git a/include/cute/arch/mma_sm90_gmma_sparse.hpp b/include/cute/arch/mma_sm90_gmma_sparse.hpp index d05e762e1..ecca91b93 100644 --- a/include/cute/arch/mma_sm90_gmma_sparse.hpp +++ b/include/cute/arch/mma_sm90_gmma_sparse.hpp @@ -36,14 +36,10 @@ namespace cute { -//////////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA PTX definitions: C = (scaleA * A) * (scaleB * B) + (scaleD * C) -//////////////////////////////////////////////////////////////////////////////////////////////////// 
- -//////////////////////////////////////////////////////////////////////////////////////////////////// - namespace SM90::GMMA::SPARSE { +//////////////////////////////////////////////////////////////////////////////////////////////////// +// GMMA PTX definitions: C = (scaleA * A) * (scaleB * B) + (scaleD * C) //////////////////////////////////////////////////////////////////////////////////////////////////// // SPARSE GMMA 64x8x32 F16+=F16*F16 @@ -70,6 +66,7 @@ struct GMMA_64x8x32_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -121,6 +118,7 @@ struct GMMA_64x8x32_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -169,6 +167,7 @@ struct GMMA_64x16x32_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -220,6 +219,7 @@ struct GMMA_64x16x32_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -269,6 +269,7 @@ struct GMMA_64x32x32_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -322,6 +323,7 @@ struct GMMA_64x32x32_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -347,119 +349,6 @@ 
struct GMMA_64x32x32_F16F16F16_RS //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x32 F16+=F16*F16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x48x32_F16F16F16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[12]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %16, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k32.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11}," - " %12," - " %13," - " %14, %15," - " p, %17, %18, %19, %20;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x32 F16+=F16*F16 -template < - GMMA::Major 
tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x48x32_F16F16F16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[12]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %19, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k32.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11}," - "{%12, %13, %14, %15}," - " %16," - " %17, %18," - " p, %20, %21, %22;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - // SPARSE GMMA 64x64x32 F16+=F16*F16 template < GMMA::Major tnspA, @@ -487,6 +376,7 @@ struct GMMA_64x64x32_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -545,6 +435,7 @@ struct GMMA_64x64x32_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -573,129 +464,6 @@ struct GMMA_64x64x32_F16F16F16_RS //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x32 F16+=F16*F16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x80x32_F16F16F16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[20]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %24, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k32.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19}," - " %20," - " %21," - " %22, %23," - " p, %25, %26, %27, %28;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - 
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x32 F16+=F16*F16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x80x32_F16F16F16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[20]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %27, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k32.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, 
%18, %19}," - "{%20, %21, %22, %23}," - " %24," - " %25, %26," - " p, %28, %29, %30;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - // SPARSE GMMA 64x96x32 F16+=F16*F16 template < GMMA::Major tnspA, @@ -725,6 +493,7 @@ struct GMMA_64x96x32_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -788,6 +557,7 @@ struct GMMA_64x96x32_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -819,139 +589,6 @@ struct GMMA_64x96x32_F16F16F16_RS //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x32 F16+=F16*F16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x112x32_F16F16F16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = 
uint32_t[28]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %32, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k32.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27}," - " %28," - " %29," - " %30, %31," - " p, %33, %34, %35, %36;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x32 F16+=F16*F16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x112x32_F16F16F16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[28]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %35, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k32.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27}," - "{%28, %29, %30, %31}," - " %32," - " %33, %34," - " p, %36, %37, %38;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - // SPARSE GMMA 64x128x32 F16+=F16*F16 template < GMMA::Major tnspA, @@ -983,6 +620,7 @@ struct GMMA_64x128x32_F16F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -1051,6 +689,7 @@ struct GMMA_64x128x32_F16F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -1085,8 +724,7 @@ struct GMMA_64x128x32_F16F16F16_RS //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x32 F16+=F16*F16 +// SPARSE GMMA 64x192x32 F16+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -1094,13 +732,13 @@ template < GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x144x32_F16F16F16_SS +struct GMMA_64x192x32_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[36]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -1114,24 +752,29 @@ struct GMMA_64x144x32_F16F16F16_SS uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + 
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %40, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k32.f16.f16.f16 " + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f16.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35}," - " %36," - " %37," - " %38, %39," - " p, %41, %42, %43, %44;\n" + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54, %55, %56;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -1141,22 +784,23 @@ struct GMMA_64x144x32_F16F16F16_SS "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif 
//////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x32 F16+=F16*F16 +// SPARSE GMMA 64x192x32 F16+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -1164,13 +808,13 @@ template < GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x144x32_F16F16F16_RS +struct GMMA_64x192x32_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[36]; + using CRegisters = uint32_t[48]; static_assert(tnspA == GMMA::Major::K, "Register source operand A must have K major layout."); @@ -1187,24 +831,29 @@ struct GMMA_64x144x32_F16F16F16_RS uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %43, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k32.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35}," - "{%36, %37, %38, %39}," - " %40," - " %41, %42," - " p, %44, %45, %46;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, 
%22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57, %58;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -1214,22 +863,23 @@ struct GMMA_64x144x32_F16F16F16_RS "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x32 F16+=F16*F16 +// SPARSE GMMA 64x256x32 F16+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -1237,13 +887,13 @@ template < GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x160x32_F16F16F16_SS +struct GMMA_64x256x32_F16F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -1258,24 +908,34 @@ struct GMMA_64x160x32_F16F16F16_SS 
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %44, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k32.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " %42, %43," - " p, %45, %46, %47, %48;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70, %71, %72;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -1286,22 +946,26 @@ struct GMMA_64x160x32_F16F16F16_SS "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), 
"+r"(d37), "+r"(d38), "+r"(d39) + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x32 F16+=F16*F16 +// SPARSE GMMA 64x256x32 F16+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -1309,13 +973,13 @@ template < GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x160x32_F16F16F16_RS +struct GMMA_64x256x32_F16F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; + using CRegisters = uint32_t[64]; static_assert(tnspA == GMMA::Major::K, "Register source operand A must have K major layout."); @@ -1333,24 +997,34 @@ struct GMMA_64x160x32_F16F16F16_RS uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, 
uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %47, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k32.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " %45, %46," - " p, %48, %49, %50;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73, %74;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -1361,22 +1035,26 @@ struct GMMA_64x160x32_F16F16F16_RS "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), 
"+r"(d63) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x32 F16+=F16*F16 +// SPARSE GMMA 64x8x32 F32+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -1384,74 +1062,48 @@ template < GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x176x32_F16F16F16_SS +struct GMMA_64x8x32_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[44]; + using CRegisters = float[4]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + float & d0, float & d1, float & d2, float & d3, 
uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %48, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k32.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43}," - " %44," - " %45," - " %46, %47," - " p, %49, %50, %51, %52;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f32.f16.f16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10, %11, %12;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x32 F16+=F16*F16 
+// SPARSE GMMA 64x8x32 F32+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -1459,76 +1111,51 @@ template < GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x176x32_F16F16F16_RS +struct GMMA_64x8x32_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[44]; + using CRegisters = float[4]; static_assert(tnspA == GMMA::Major::K, "Register source operand A must have K major layout."); CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + float & d0, float & d1, float & d2, float & d3, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %51, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k32.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, 
%23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43}," - "{%44, %45, %46, %47}," - " %48," - " %49, %50," - " p, %52, %53, %54;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f32.f16.f16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13, %14;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x192x32 F16+=F16*F16 +// SPARSE GMMA 64x16x32 F32+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -1536,74 +1163,50 @@ template < GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x192x32_F16F16F16_SS +struct GMMA_64x16x32_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using 
CRegisters = uint32_t[48]; + using CRegisters = float[8]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %52, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k32.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " %50, %51," - " p, %53, %54, %55, %56;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14, %15, %16;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), 
"+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x192x32 F16+=F16*F16 +// SPARSE GMMA 64x16x32 F32+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -1611,78 +1214,53 @@ template < GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x192x32_F16F16F16_RS +struct GMMA_64x16x32_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; + using CRegisters = float[8]; static_assert(tnspA == GMMA::Major::K, "Register source operand A must have K major layout."); CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, 
uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %55, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k32.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " %53, %54," - " p, %56, %57, %58;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17, %18;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), 
"+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x32 F16+=F16*F16 +// SPARSE GMMA 64x32x32 F32+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -1690,79 +1268,55 @@ template < GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x208x32_F16F16F16_SS +struct GMMA_64x32x32_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[52]; + using CRegisters = float[16]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, 
- uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %56, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k32.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51}," - " %52," - " %53," - " %54, %55," - " p, %57, %58, %59, %60;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22, %23, %24;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), 
"+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x32 F16+=F16*F16 +// SPARSE GMMA 64x32x32 F32+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -1770,13 +1324,13 @@ template < GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x208x32_F16F16F16_RS +struct GMMA_64x32x32_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[52]; + using CRegisters = float[16]; static_assert(tnspA == GMMA::Major::K, "Register source operand A must have K major layout."); @@ -1784,68 +1338,44 @@ struct GMMA_64x208x32_F16F16F16_RS CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - 
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %59, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k32.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51}," - "{%52, %53, %54, %55}," - " %56," - " %57, %58," - " p, %60, %61, %62;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25, %26;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), 
"+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x32 F16+=F16*F16 +// SPARSE GMMA 64x64x32 F32+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -1853,81 +1383,65 @@ template < GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x224x32_F16F16F16_SS +struct GMMA_64x64x32_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; + using CRegisters = float[32]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - 
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %60, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k32.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " %58, %59," - " p, %61, %62, %63, %64;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + 
" %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38, %39, %40;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x32 F16+=F16*F16 +// SPARSE GMMA 64x64x32 F32+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -1935,13 +1449,13 @@ template < GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > 
-struct GMMA_64x224x32_F16F16F16_RS +struct GMMA_64x64x32_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; + using CRegisters = float[32]; static_assert(tnspA == GMMA::Major::K, "Register source operand A must have K major layout."); @@ -1949,70 +1463,54 @@ struct GMMA_64x224x32_F16F16F16_RS CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %63, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k32.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " %61, %62," - " p, %64, %65, %66;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41, %42;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) : "r"(a00), "r"(a01), 
"r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x32 F16+=F16*F16 +// SPARSE GMMA 64x96x32 F32+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -2020,84 +1518,75 @@ template < GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x240x32_F16F16F16_SS +struct GMMA_64x96x32_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[60]; + using CRegisters = float[48]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & 
d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %64, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k32.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59}," - " %60," - " %61," - " %62, %63," - " p, %65, %66, %67, %68;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54, %55, %56;\n" "}\n" - : "+r"(d00), 
"+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x32 F16+=F16*F16 +// SPARSE GMMA 64x96x32 F32+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -2105,13 +1594,13 @@ 
template < GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x240x32_F16F16F16_RS +struct GMMA_64x96x32_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[60]; + using CRegisters = float[48]; static_assert(tnspA == GMMA::Major::K, "Register source operand A must have K major layout."); @@ -2119,72 +1608,64 @@ struct GMMA_64x240x32_F16F16F16_RS CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, 
float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %67, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k32.f16.f16.f16 " + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f32.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59}," - "{%60, %61, %62, %63}," - " %64," - " %65, %66," - " p, %68, %69, %70;\n" + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57, %58;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), 
"+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x256x32 F16+=F16*F16 +// SPARSE GMMA 64x128x32 F32+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -2192,42 +1673,43 @@ template < GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x256x32_F16F16F16_SS +struct GMMA_64x128x32_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; + using CRegisters = float[64]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, 
uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %68, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k32.f16.f16.f16 " + "wgmma.mma_async.sp.sync.aligned.m64n128k32.f32.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, 
%15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -2241,35 +1723,35 @@ struct GMMA_64x256x32_F16F16F16_SS " %66, %67," " p, %69, %70, %71, %72;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x128x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x256x32 F16+=F16*F16 +// SPARSE GMMA 64x128x32 F32+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -2277,13 +1759,13 @@ template < GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x256x32_F16F16F16_RS +struct GMMA_64x128x32_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; + using CRegisters = float[64]; static_assert(tnspA == GMMA::Major::K, "Register source operand A must have K major layout."); @@ -2291,31 +1773,32 @@ struct GMMA_64x256x32_F16F16F16_RS CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, 
- uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %71, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k32.f16.f16.f16 " + "wgmma.mma_async.sp.sync.aligned.m64n128k32.f32.f16.f16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -2329,35 +1812,35 @@ struct GMMA_64x256x32_F16F16F16_RS " %69, %70," " p, %72, %73, %74;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), 
"+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x8x32 F32+=F16*F16 +// SPARSE GMMA 64x192x32 F32+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -2365,47 +1848,105 @@ template < GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x8x32_F32F16F16_SS +struct GMMA_64x192x32_F32F16F16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = 
uint64_t[1]; - using CRegisters = float[4]; + using CRegisters = float[96]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %8, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k32.f32.f16.f16 " - "{%0, %1, %2, %3}," - " %4," - " %5," - " %6, %7," - " p, %9, %10, %11, %12;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f32.f16.f16 " + "{%0, %1, 
%2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102, %103, %104;\n" "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x8x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x8x32 F32+=F16*F16 +// SPARSE GMMA 64x192x32 F32+=F16*F16 template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -2413,100 +1954,412 @@ template < GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x8x32_F32F16F16_RS +struct GMMA_64x192x32_F32F16F16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[4]; + using CRegisters = float[96]; static_assert(tnspA == GMMA::Major::K, "Register source operand A must have K major layout."); CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, 
float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %11, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k32.f32.f16.f16 " - "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," - " %8," - " %9, %10," - " p, %12, %13, %14;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105, %106;\n" "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } 
-}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x16x32 F32+=F16*F16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x16x32_F32F16F16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 
SPARSE GMMA 64x256x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[8]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + 
float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134, %135, %136;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), 
"+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = 
uint64_t[1]; + using CRegisters = float[128]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, 
float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137, %138;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), 
"+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, float & d0, float & d1, float & d2, float & d3, - float & d4, float & d5, float & d6, float & d7, 
uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %12, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k32.f32.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - " %8," - " %9," - " %10, %11," - " p, %13, %14, %15, %16;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f32.bf16.bf16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10, %11, %12;\n" "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), - "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x16x32 F32+=F16*F16 +// SPARSE GMMA 64x8x32 F32+=BF16*BF16 template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -2514,13 +2367,13 @@ template < GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x16x32_F32F16F16_RS +struct GMMA_64x8x32_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[8]; + using CRegisters = float[4]; static_assert(tnspA == GMMA::Major::K, "Register source operand A must have K major layout."); @@ -2529,37 +2382,36 @@ struct GMMA_64x16x32_F32F16F16_RS fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, 
uint32_t const& a3, uint64_t const& desc_b, float & d0, float & d1, float & d2, float & d3, - float & d4, float & d5, float & d6, float & d7, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %15, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k32.f32.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "{%8, %9, %10, %11}," - " %12," - " %13, %14," - " p, %16, %17, %18;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f32.bf16.bf16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13, %14;\n" "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), - "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x32x32 F32+=F16*F16 +// SPARSE GMMA 64x16x32 F32+=BF16*BF16 template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -2567,54 +2419,50 @@ template < GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x32x32_F32F16F16_SS +struct GMMA_64x16x32_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[16]; + using CRegisters = float[8]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, 
uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %20, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k32.f32.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - " %16," - " %17," - " %18, %19," - " p, %21, %22, %23, %24;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14, %15, %16;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x32x32 F32+=F16*F16 +// SPARSE GMMA 64x16x32 F32+=BF16*BF16 template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -2622,58 +2470,53 @@ template < GMMA::ScaleIn 
scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x32x32_F32F16F16_RS +struct GMMA_64x16x32_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[16]; + using CRegisters = float[8]; static_assert(tnspA == GMMA::Major::K, "Register source operand A must have K major layout."); CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %23, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k32.f32.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - "{%16, %17, %18, %19}," - " %20," - " %21, %22," - " p, %24, %25, %26;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17, %18;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), 
"l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x32 F32+=F16*F16 +// SPARSE GMMA 64x32x32 F32+=BF16*BF16 template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -2681,13 +2524,13 @@ template < GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x48x32_F32F16F16_SS +struct GMMA_64x32x32_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[24]; + using CRegisters = float[16]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -2696,46 +2539,40 @@ struct GMMA_64x48x32_F32F16F16_SS float & d04, float & d05, float & d06, float & d07, float & d08, float & d09, float & d10, float & d11, float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %28, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k32.f32.f16.f16 " + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," 
- " %25," - " %26, %27," - " p, %29, %30, %31, %32;\n" + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22, %23, %24;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x32 F32+=F16*F16 +// SPARSE GMMA 64x32x32 F32+=BF16*BF16 template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -2743,13 +2580,13 @@ template < GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x48x32_F32F16F16_RS +struct GMMA_64x32x32_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[24]; + using CRegisters = float[16]; static_assert(tnspA == GMMA::Major::K, "Register source operand A must have K major layout."); @@ -2761,45 +2598,40 @@ struct GMMA_64x48x32_F32F16F16_RS float & d04, float & d05, float & d06, float & d07, float & d08, float & d09, float & d10, float & d11, float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & 
d21, float & d22, float & d23, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %31, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k32.f32.f16.f16 " + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " %29, %30," - " p, %32, %33, %34;\n" + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25, %26;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x64x32 F32+=F16*F16 +// SPARSE GMMA 64x64x32 F32+=BF16*BF16 template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -2807,7 +2639,7 @@ template < GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x64x32_F32F16F16_SS +struct GMMA_64x64x32_F32BF16BF16_SS { using DRegisters = void; using 
ARegisters = uint64_t[1]; @@ -2830,11 +2662,12 @@ struct GMMA_64x64x32_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %36, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k32.f32.f16.f16 " + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -2857,14 +2690,14 @@ struct GMMA_64x64x32_F32F16F16_SS "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x64x32 F32+=F16*F16 +// SPARSE GMMA 64x64x32 F32+=BF16*BF16 template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -2872,7 +2705,7 @@ template < GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x64x32_F32F16F16_RS +struct GMMA_64x64x32_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -2898,11 +2731,12 @@ struct GMMA_64x64x32_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %39, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k32.f32.f16.f16 " + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, 
%15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -2925,15 +2759,14 @@ struct GMMA_64x64x32_F32F16F16_RS "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x32 F32+=F16*F16 +// SPARSE GMMA 64x96x32 F32+=BF16*BF16 template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -2941,13 +2774,13 @@ template < GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x80x32_F32F16F16_SS +struct GMMA_64x96x32_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[40]; + using CRegisters = float[48]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -2962,24 +2795,28 @@ struct GMMA_64x80x32_F32F16F16_SS float & d28, float & d29, float & d30, float & d31, float & d32, float & d33, float & d34, float & d35, float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %44, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k32.f32.f16.f16 " + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " 
%8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " %42, %43," - " p, %45, %46, %47, %48;\n" + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54, %55, %56;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -2990,22 +2827,22 @@ struct GMMA_64x80x32_F32F16F16_SS "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x32 F32+=F16*F16 +// SPARSE GMMA 64x96x32 F32+=BF16*BF16 template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -3013,162 +2850,13 @@ template < GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x80x32_F32F16F16_RS +struct GMMA_64x96x32_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[40]; - - static_assert(tnspA == 
GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %47, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k32.f32.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " %45, %46," - " p, %48, %49, %50;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x96x32 F32+=F16*F16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x96x32_F32F16F16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[48]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %52, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k32.f32.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " %50, %51," - " p, %53, 
%54, %55, %56;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x96x32 F32+=F16*F16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x96x32_F32F16F16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[48]; + using CRegisters = float[48]; static_assert(tnspA == GMMA::Major::K, "Register source operand A must have K major layout."); @@ -3192,11 +2880,12 @@ struct GMMA_64x96x32_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %55, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k32.f32.f16.f16 " + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, 
%5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -3225,181 +2914,14 @@ struct GMMA_64x96x32_F32F16F16_RS "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x32 F32+=F16*F16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x112x32_F32F16F16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[56]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm 
volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %60, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k32.f32.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " %58, %59," - " p, %61, %62, %63, %64;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x32 F32+=F16*F16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x112x32_F32F16F16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using 
ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[56]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %63, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k32.f32.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " %61, %62," - " p, %64, %65, %66;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), 
"+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x128x32 F32+=F16*F16 +// SPARSE GMMA 64x128x32 F32+=BF16*BF16 template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -3407,7 +2929,7 @@ template < GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x128x32_F32F16F16_SS +struct GMMA_64x128x32_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -3438,11 +2960,12 @@ struct GMMA_64x128x32_F32F16F16_SS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %68, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k32.f32.f16.f16 " + "wgmma.mma_async.sp.sync.aligned.m64n128k32.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -3477,14 +3000,14 @@ struct GMMA_64x128x32_F32F16F16_SS "r"(e), "n"(int32_t(spsel)), 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x128x32 F32+=F16*F16 +// SPARSE GMMA 64x128x32 F32+=BF16*BF16 template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -3492,7 +3015,7 @@ template < GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x128x32_F32F16F16_RS +struct GMMA_64x128x32_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -3526,11 +3049,12 @@ struct GMMA_64x128x32_F32F16F16_RS GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %71, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k32.f32.f16.f16 " + "wgmma.mma_async.sp.sync.aligned.m64n128k32.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -3565,15 +3089,14 @@ struct GMMA_64x128x32_F32F16F16_RS "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x32 F32+=F16*F16 +// SPARSE GMMA 64x192x32 F32+=BF16*BF16 template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -3581,13 +3104,13 @@ template < GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x144x32_F32F16F16_SS +struct GMMA_64x192x32_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[72]; + using CRegisters = float[96]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -3610,15 +3133,22 @@ struct GMMA_64x144x32_F32F16F16_SS float & d60, float & d61, float & d62, float & d63, float & d64, float & d65, float & d66, float & d67, float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %76, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k32.f32.f16.f16 " + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -3627,11 +3157,14 @@ struct GMMA_64x144x32_F32F16F16_SS " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - " %72," - " %73," - " %74, %75," - " p, %77, %78, %79, %80;\n" + " %64, 
%65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102, %103, %104;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -3650,22 +3183,26 @@ struct GMMA_64x144x32_F32F16F16_SS "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x32 F32+=F16*F16 +// SPARSE GMMA 64x192x32 F32+=BF16*BF16 template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -3673,13 +3210,13 @@ template < GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x144x32_F32F16F16_RS +struct GMMA_64x192x32_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[72]; + 
using CRegisters = float[96]; static_assert(tnspA == GMMA::Major::K, "Register source operand A must have K major layout."); @@ -3705,15 +3242,22 @@ struct GMMA_64x144x32_F32F16F16_RS float & d60, float & d61, float & d62, float & d63, float & d64, float & d65, float & d66, float & d67, float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %79, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k32.f32.f16.f16 " + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -3722,11 +3266,14 @@ struct GMMA_64x144x32_F32F16F16_RS " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - "{%72, %73, %74, %75}," - " %76," - " %77, %78," - " p, %80, %81, %82;\n" + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105, %106;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -3745,22 +3292,26 @@ struct GMMA_64x144x32_F32F16F16_RS "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), 
"+f"(d63), "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x32 F32+=F16*F16 +// SPARSE GMMA 64x256x32 F32+=BF16*BF16 template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -3768,46 +3319,59 @@ template < GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x160x32_F32F16F16_SS +struct GMMA_64x256x32_F32BF16BF16_SS { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[80]; + using CRegisters = float[128]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float 
& d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + 
float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %84, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k32.f32.f16.f16 " + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -3817,47 +3381,63 @@ struct GMMA_64x160x32_F32F16F16_SS " %48, %49, %50, %51, %52, %53, %54, %55, " " %56, %57, %58, %59, %60, %61, %62, %63, " " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - " %80," - " %81," - " %82, %83," - " p, %85, %86, %87, %88;\n" + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134, %135, %136;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), 
"+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), 
"+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x32 F32+=F16*F16 +// SPARSE GMMA 64x256x32 F32+=BF16*BF16 template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -3865,49 +3445,62 @@ template < GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x160x32_F32F16F16_RS +struct GMMA_64x256x32_F32BF16BF16_RS { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[80]; + using CRegisters = float[128]; static_assert(tnspA == GMMA::Major::K, "Register source operand A must have K major layout."); CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, 
- float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, 
float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %87, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k32.f32.f16.f16 " + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f32.bf16.bf16 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -3917,61 +3510,267 @@ struct GMMA_64x160x32_F32F16F16_RS " %48, %49, %50, %51, %52, %53, %54, %55, " " %56, %57, %58, %59, %60, %61, %62, %63, " " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - "{%80, %81, %82, %83}," - " %84," - " %85, %86," - " p, %88, %89, %90;\n" + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, 
%121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137, %138;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), 
"+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k16.f32.tf32.tf32 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k16.f32.tf32.tf32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x8x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif + } +}; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x32 F32+=F16*F16 +// SPARSE GMMA 64x16x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x16x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x16 TN F32+=TF32*TF32 template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x176x32_F32F16F16_SS +struct GMMA_64x32x16_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[88]; + using CRegisters = float[16]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -3980,103 +3779,52 @@ struct GMMA_64x176x32_F32F16F16_SS float & 
d04, float & d05, float & d06, float & d07, float & d08, float & d09, float & d10, float & d11, float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %92, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k32.f32.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - " %88," - " %89," - " %90, %91," - " p, %93, %94, %95, %96;\n" + 
"setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x32 
F32+=F16*F16 +// SPARSE GMMA 64x32x16 TN F32+=TF32*TF32 template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x176x32_F32F16F16_RS +struct GMMA_64x32x16_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[88]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); + using CRegisters = float[16]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -4085,99 +3833,52 @@ struct GMMA_64x176x32_F32F16F16_RS float & d04, float & d05, float & d06, float & d07, float & d08, float & d09, float & d10, float & d11, float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %95, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k32.f32.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - "{%88, %89, %90, %91}," - " %92," - " %93, %94," - " p, %96, %97, %98;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), 
"+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x192x32 F32+=F16*F16 +// SPARSE GMMA 64x64x16 TN F32+=TF32*TF32 template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x192x32_F32F16F16_SS +struct GMMA_64x64x16_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[96]; + using CRegisters = float[32]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -4190,47 +3891,24 @@ struct GMMA_64x192x32_F32F16F16_SS float & d20, float & d21, float & d22, float & d23, float & d24, float & d25, float & d26, float & d27, float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, 
float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, - float & d88, float & d89, float & d90, float & d91, - float & d92, float & d93, float & d94, float & d95, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %100, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k32.f32.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - " %96," - " %97," - " %98, %99," - " p, %101, %102, %103, %104;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -4239,53 +3917,32 @@ struct GMMA_64x192x32_F32F16F16_SS "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - 
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), - "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), - "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x192x32 F32+=F16*F16 +// SPARSE GMMA 64x64x16 TN F32+=TF32*TF32 template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x192x32_F32F16F16_RS +struct GMMA_64x64x16_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[96]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major 
layout."); + using CRegisters = float[32]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -4298,47 +3955,24 @@ struct GMMA_64x192x32_F32F16F16_RS float & d20, float & d21, float & d22, float & d23, float & d24, float & d25, float & d26, float & d27, float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, - float & d88, float & d89, float & d90, float & d91, - float & d92, float & d93, float & d94, float & d95, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %103, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k32.f32.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " 
%80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - "{%96, %97, %98, %99}," - " %100," - " %101, %102," - " p, %104, %105, %106;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -4347,319 +3981,210 @@ struct GMMA_64x192x32_F32F16F16_RS "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), - "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), - "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F32F16F16_RS without 
CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x32 F32+=F16*F16 +// SPARSE GMMA 64x96x16 TN F32+=TF32*TF32 template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x208x32_F32F16F16_SS +struct GMMA_64x96x16_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[104]; + using CRegisters = float[48]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float 
& d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %108, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k32.f32.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, 
%102, %103}," - " %104," - " %105," - " %106, %107," - " p, %109, %110, %111, %112;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + 
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x32 F32+=F16*F16 +// SPARSE GMMA 64x96x16 TN F32+=TF32*TF32 template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x208x32_F32F16F16_RS +struct GMMA_64x96x16_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[104]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); + using CRegisters = float[48]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, 
float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & 
d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %111, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k32.f32.f16.f16 " + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k16.f32.tf32.tf32 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - "{%104, %105, %106, %107}," - " %108," - " %109, %110," - " p, %112, %113, %114;\n" + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), 
"+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x32 F32+=F16*F16 +// SPARSE GMMA 64x128x16 TN F32+=TF32*TF32 template < - GMMA::Major tnspA, - 
GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x224x32_F32F16F16_SS +struct GMMA_64x128x16_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[112]; + using CRegisters = float[64]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float 
& d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %116, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k32.f32.f16.f16 " + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k16.f32.tf32.tf32 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -4667,119 +4192,83 @@ struct GMMA_64x224x32_F32F16F16_SS " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, 
%105, %106, %107, %108, %109, %110, %111}," - " %112," - " %113," - " %114, %115," - " p, %117, %118, %119, %120;\n" + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + 
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x32 F32+=F16*F16 +// SPARSE GMMA 64x128x16 TN F32+=TF32*TF32 template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x224x32_F32F16F16_RS +struct GMMA_64x128x16_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[112]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); + using CRegisters = float[64]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & 
d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & 
d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %119, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k32.f32.f16.f16 " + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k16.f32.tf32.tf32 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -4787,366 +4276,91 @@ struct GMMA_64x224x32_F32F16F16_RS " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - "{%112, %113, %114, %115}," - " %116," - " %117, %118," - " p, %120, %121, %122;\n" + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), 
"+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x32 F32+=F16*F16 +// SPARSE GMMA 64x192x16 TN F32+=TF32*TF32 template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x240x32_F32F16F16_SS +struct GMMA_64x192x16_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[120]; + using CRegisters = float[96]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & 
d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %124, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k32.f32.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - " %120," - " %121," - " %122, %123," - " p, %125, %126, %127, %128;\n" - "}\n" - : "+f"(d000), 
"+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 
64x240x32 F32+=F16*F16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x240x32_F32F16F16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[120]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float 
& d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %127, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k32.f32.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - "{%120, %121, %122, %123}," - " %124," - " %125, %126," - " p, %128, %129, %130;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), 
"+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x256x32 F32+=F16*F16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x256x32_F32F16F16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[128]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, 
float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, - float & d120, float & d121, float & d122, float & d123, - float & d124, float & d125, float & d126, float & d127, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & 
d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %132, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k32.f32.f16.f16 " + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k16.f32.tf32.tf32 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -5158,123 +4372,99 @@ struct GMMA_64x256x32_F32F16F16_SS " %64, %65, %66, %67, %68, %69, %70, %71, " " %72, %73, %74, %75, %76, %77, %78, %79, " " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, 
%114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - " %128," - " %129," - " %130, %131," - " p, %133, %134, %135, %136;\n" + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), - "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), - "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), 
"+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x256x32 F32+=F16*F16 +// SPARSE GMMA 64x192x16 TN F32+=TF32*TF32 template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x256x32_F32F16F16_RS +struct GMMA_64x192x16_F32TF32TF32_RS_TN { using DRegisters = void; using 
ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[128]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); + using CRegisters = float[96]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & 
d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, - float & d120, float & d121, float & d122, float & d123, - float & d124, float & d125, float & d126, float & d127, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %135, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k32.f32.f16.f16 " + "setp.ne.b32 p, %103, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n192k16.f32.tf32.tf32 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -5286,1725 +4476,11 @@ struct GMMA_64x256x32_F32F16F16_RS " %64, %65, %66, %67, %68, %69, %70, %71, " " %72, %73, %74, %75, %76, %77, %78, %79, " " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - "{%128, %129, %130, %131}," - " %132," - " %133, %134," - " p, %136, %137, %138;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), 
"+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), - "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), - "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x8x32 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x8x32_F32BF16BF16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[4]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %8, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k32.f32.bf16.bf16 " - "{%0, %1, %2, %3}," - " %4," - " %5," - " %6, %7," - " p, %9, %10, %11, %12;\n" - "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x8x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x8x32 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x8x32_F32BF16BF16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[4]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %11, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k32.f32.bf16.bf16 " - "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," - " %8," - " %9, %10," - " p, %12, %13, %14;\n" - "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x16x32 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero -> -struct GMMA_64x16x32_F32BF16BF16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[8]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, - float & d4, float & d5, float & d6, float & d7, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %12, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k32.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - " %8," - " %9," - " %10, %11," - " p, %13, %14, %15, %16;\n" - "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), - "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x16x32 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x16x32_F32BF16BF16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[8]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - float & d0, 
float & d1, float & d2, float & d3, - float & d4, float & d5, float & d6, float & d7, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %15, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k32.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "{%8, %9, %10, %11}," - " %12," - " %13, %14," - " p, %16, %17, %18;\n" - "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), - "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x32x32 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x32x32_F32BF16BF16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[16]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %20, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k32.f32.bf16.bf16 " - "{%0, %1, %2, 
%3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - " %16," - " %17," - " %18, %19," - " p, %21, %22, %23, %24;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x32x32 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x32x32_F32BF16BF16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[16]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %23, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k32.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, 
%15}," - "{%16, %17, %18, %19}," - " %20," - " %21, %22," - " p, %24, %25, %26;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x32 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x48x32_F32BF16BF16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[24]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %28, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k32.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, 
%23}," - " %24," - " %25," - " %26, %27," - " p, %29, %30, %31, %32;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x32 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x48x32_F32BF16BF16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[24]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm 
volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %31, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k32.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " %29, %30," - " p, %32, %33, %34;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x64x32 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x64x32_F32BF16BF16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[32]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & 
d29, float & d30, float & d31, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %36, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k32.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - " %32," - " %33," - " %34, %35," - " p, %37, %38, %39, %40;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x64x32 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x64x32_F32BF16BF16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[32]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& 
a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %39, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k32.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - "{%32, %33, %34, %35}," - " %36," - " %37, %38," - " p, %40, %41, %42;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x32 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn 
scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x80x32_F32BF16BF16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[40]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %44, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k32.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " %42, %43," - " p, %45, %46, %47, %48;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x32 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x80x32_F32BF16BF16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[40]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %47, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k32.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - 
" %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " %45, %46," - " p, %48, %49, %50;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x96x32 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x96x32_F32BF16BF16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[48]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, 
float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %52, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k32.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " %50, %51," - " p, %53, %54, %55, %56;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x96x32 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x96x32_F32BF16BF16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[48]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %55, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k32.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " %53, %54," - " p, %56, %57, %58;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), 
"+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x32 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x112x32_F32BF16BF16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[56]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & 
d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %60, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k32.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " %58, %59," - " p, %61, %62, %63, %64;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x32 
F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x112x32_F32BF16BF16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[56]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %63, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k32.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, 
%58, %59}," - " %60," - " %61, %62," - " p, %64, %65, %66;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x128x32 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x128x32_F32BF16BF16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[64]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float 
& d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %68, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k32.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - " %64," - " %65," - " %66, %67," - " p, %69, %70, %71, %72;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) - : 
"l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x128x32 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x128x32_F32BF16BF16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[64]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - uint32_t 
const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %71, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k32.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - "{%64, %65, %66, %67}," - " %68," - " %69, %70," - " p, %72, %73, %74;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x32 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - 
GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x144x32_F32BF16BF16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[72]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %76, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k32.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, 
%57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - " %72," - " %73," - " %74, %75," - " p, %77, %78, %79, %80;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x32 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x144x32_F32BF16BF16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[72]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must 
have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %79, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k32.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - "{%72, %73, %74, %75}," - " %76," - " %77, %78," - " p, %80, %81, %82;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), 
"+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x32 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x160x32_F32BF16BF16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[80]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, 
float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %84, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k32.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - " %80," - " %81," - " %82, %83," - " p, %85, %86, %87, %88;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), 
"+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x32 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x160x32_F32BF16BF16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[80]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - 
float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %87, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k32.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - "{%80, %81, %82, %83}," - " %84," - " %85, %86," - " p, %88, %89, %90;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), 
"+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x32 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x176x32_F32BF16BF16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[88]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float 
& d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %92, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k32.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - " %88," - " %89," - " %90, %91," - " p, %93, %94, %95, %96;\n" + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -7027,26913 +4503,364 @@ struct GMMA_64x176x32_F32BF16BF16_SS "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), 
"+f"(d85), "+f"(d86), "+f"(d87) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x32 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x176x32_F32BF16BF16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[88]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & 
d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %95, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k32.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - "{%88, %89, %90, %91}," - " %92," - " %93, %94," - " p, %96, %97, %98;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), 
"+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x192x32 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x192x32_F32BF16BF16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[96]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, 
- float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, - float & d88, float & d89, float & d90, float & d91, - float & d92, float & d93, float & d94, float & d95, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %100, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k32.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - " %96," - " %97," - " %98, %99," - " p, %101, %102, %103, %104;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), 
"+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), - "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), - "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x192x32 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x192x32_F32BF16BF16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[96]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float 
& d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, - float & d88, float & d89, float & d90, float & d91, - float & d92, float & d93, float & d94, float & d95, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %103, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k32.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - "{%96, %97, %98, %99}," - " %100," - " %101, %102," - " p, %104, %105, %106;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), 
"+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), - "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), - "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x32 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x208x32_F32BF16BF16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[104]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d000, float 
& d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %108, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k32.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, 
%46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - " %104," - " %105," - " %106, %107," - " p, %109, %110, %111, %112;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F32BF16BF16_SS 
without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x32 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x208x32_F32BF16BF16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[104]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & 
d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %111, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k32.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - "{%104, %105, %106, %107}," - " %108," - " %109, %110," - " p, %112, %113, %114;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), 
"+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x32 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x224x32_F32BF16BF16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[112]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & 
d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %116, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k32.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, 
" - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - " %112," - " %113," - " %114, %115," - " p, %117, %118, %119, %120;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x224x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x32 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x224x32_F32BF16BF16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[112]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float 
& d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %119, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k32.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - "{%112, %113, %114, %115}," - " %116," - " %117, %118," - " p, %120, %121, %122;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - 
"+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x32 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x240x32_F32BF16BF16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[120]; - - CUTE_HOST_DEVICE static void - fma(uint64_t 
const& desc_a, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %124, 0;\n" - 
"wgmma.mma_async.sp.sync.aligned.m64n240k32.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - " %120," - " %121," - " %122, %123," - " p, %125, %126, %127, %128;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - 
"+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x32 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x240x32_F32BF16BF16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[120]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & 
d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %127, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k32.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, 
%81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - "{%120, %121, %122, %123}," - " %124," - " %125, %126," - " p, %128, %129, %130;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), 
"n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x256x32 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x256x32_F32BF16BF16_SS -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[128]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, 
float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, - float & d120, float & d121, float & d122, float & d123, - float & d124, float & d125, float & d126, float & d127, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %132, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k32.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - " %128," - " %129," - " %130, %131," - " p, %133, %134, %135, %136;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), 
"+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), - "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), - "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x256x32 F32+=BF16*BF16 -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - 
GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x256x32_F32BF16BF16_RS -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[128]; - - static_assert(tnspA == GMMA::Major::K, - "Register source operand A must have K major layout."); - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, 
float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, - float & d120, float & d121, float & d122, float & d123, - float & d124, float & d125, float & d126, float & d127, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %135, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k32.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - "{%128, %129, %130, %131}," - " %132," - " %133, %134," - " p, %136, %137, %138;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), 
"+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), - "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), - "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x8x16 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x8x16_F32TF32TF32_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[4]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t 
const& desc_b, - float & d0, float & d1, float & d2, float & d3, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %8, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k16.f32.tf32.tf32 " - "{%0, %1, %2, %3}," - " %4," - " %5," - " %6, %7," - " p, %9, %10;\n" - "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x8x16 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x8x16_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[4]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %11, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k16.f32.tf32.tf32 " - "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," - " %8," - " %9, %10," - " p, %12, %13;\n" - "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x16x16 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x16x16_F32TF32TF32_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[8]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, - float & d4, float & d5, float & d6, float & d7, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %12, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k16.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - " %8," - " %9," - " %10, %11," - " p, %13, %14;\n" - "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), - "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x16x16 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x16x16_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using 
ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[8]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, - float & d4, float & d5, float & d6, float & d7, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %15, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k16.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "{%8, %9, %10, %11}," - " %12," - " %13, %14," - " p, %16, %17;\n" - "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), - "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x32x16 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x32x16_F32TF32TF32_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[16]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { 
-#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %20, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k16.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - " %16," - " %17," - " %18, %19," - " p, %21, %22;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x32x16 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x32x16_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[16]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %23, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k16.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, 
%14, %15}," - "{%16, %17, %18, %19}," - " %20," - " %21, %22," - " p, %24, %25;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x16 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x48x16_F32TF32TF32_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[24]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %28, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k16.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " %26, %27," - " p, %29, 
%30;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x16 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x48x16_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[24]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %31, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k16.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, 
%22, %23}," - "{%24, %25, %26, %27}," - " %28," - " %29, %30," - " p, %32, %33;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x64x16 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x64x16_F32TF32TF32_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[32]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %36, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k16.f32.tf32.tf32 " - "{%0, %1, %2, %3, 
%4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - " %32," - " %33," - " %34, %35," - " p, %37, %38;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x64x16 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x64x16_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[32]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - uint32_t const& e, - 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %39, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k16.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - "{%32, %33, %34, %35}," - " %36," - " %37, %38," - " p, %40, %41;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x16 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x80x16_F32TF32TF32_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[40]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, 
float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %44, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k16.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " %42, %43," - " p, %45, %46;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x16 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero -> -struct GMMA_64x80x16_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[40]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %47, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k16.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " %45, %46," - " p, %48, %49;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - 
"r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x96x16 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x96x16_F32TF32TF32_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[48]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %52, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k16.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, 
%42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " %50, %51," - " p, %53, %54;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x96x16 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x96x16_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[48]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & 
d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %55, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k16.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " %53, %54," - " p, %56, %57;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x16 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn 
scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x112x16_F32TF32TF32_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[56]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %60, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k16.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " %58, %59," - " p, %61, %62;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - 
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x16 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x112x16_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[56]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & 
d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %63, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k16.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " %61, %62," - " p, %64, %65;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x128x16 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x128x16_F32TF32TF32_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[64]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %68, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k16.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - " %64," - " %65," - " %66, %67," - " p, %69, %70;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), 
- "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x128x16 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x128x16_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[64]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float 
& d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %71, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k16.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - "{%64, %65, %66, %67}," - " %68," - " %69, %70," - " p, %72, %73;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - 
"r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x16 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x144x16_F32TF32TF32_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[72]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %76, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k16.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - " %72," - " %73," - " %74, %75," - " p, %77, %78;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x16 TN F32+=TF32*TF32 -template < - 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x144x16_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[72]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %79, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k16.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, 
%49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - "{%72, %73, %74, %75}," - " %76," - " %77, %78," - " p, %80, %81;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x16 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x160x16_F32TF32TF32_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[80]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - 
uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %84, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k16.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - " %80," - " %81," - " %82, %83," - " p, %85, %86;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), 
"+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x16 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x160x16_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[80]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & 
d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %87, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k16.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - "{%80, %81, %82, %83}," - " %84," - " %85, %86," - " p, %88, %89;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), 
"+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x16 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x176x16_F32TF32TF32_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[88]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float 
& d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %92, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k16.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - " %88," - " %89," - " %90, %91," - " p, %93, %94;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), 
"+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x16 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x176x16_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[88]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - 
float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %95, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k16.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - "{%88, %89, %90, %91}," - " %92," - " %93, %94," - " p, %96, %97;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), 
"+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x192x16 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x192x16_F32TF32TF32_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[96]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, 
float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, - float & d88, float & d89, float & d90, float & d91, - float & d92, float & d93, float & d94, float & d95, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %100, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k16.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - " %96," - " %97," - " %98, %99," - " p, %101, %102;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), 
"+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), - "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), - "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x192x16 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x192x16_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[96]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & 
d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, - float & d88, float & d89, float & d90, float & d91, - float & d92, float & d93, float & d94, float & d95, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %103, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k16.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - "{%96, %97, %98, %99}," - " %100," - " %101, %102," - " p, %104, %105;\n" - "}\n" - : "+f"(d00), 
"+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), - "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), - "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x16 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x208x16_F32TF32TF32_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[104]; - - CUTE_HOST_DEVICE 
static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %108, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k16.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, 
%31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - " %104," - " %105," - " %106, %107," - " p, %109, %110;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x208x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x16 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x208x16_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[104]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & 
d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %111, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k16.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - "{%104, %105, %106, %107}," - " %108," - " %109, %110," - " p, %112, %113;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), 
"+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x16 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x224x16_F32TF32TF32_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[112]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & 
d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %116, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k16.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " 
%96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - " %112," - " %113," - " %114, %115," - " p, %117, %118;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x16 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x224x16_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[112]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - 
float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %119, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k16.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - "{%112, %113, %114, %115}," - " %116," - " %117, %118," - " p, %120, %121;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), 
"+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x16 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x240x16_F32TF32TF32_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[120]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, 
float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %124, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k16.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, 
%70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - " %120," - " %121," - " %122, %123," - " p, %125, %126;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x16 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x240x16_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[120]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, 
float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %127, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k16.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - "{%120, %121, %122, %123}," - " %124," - " %125, %126," - " p, %128, %129;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), 
- "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x256x16 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x256x16_F32TF32TF32_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = 
uint64_t[1]; - using CRegisters = float[128]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, - float & d120, float & d121, float & d122, float & d123, - float & d124, float & d125, float & d126, float & d127, - 
uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %132, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k16.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - " %128," - " %129," - " %130, %131," - " p, %133, %134;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), 
"+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), - "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), - "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x256x16 TN F32+=TF32*TF32 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x256x16_F32TF32TF32_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[128]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & 
d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, - float & d120, float & d121, float & d122, float & d123, - float & d124, float & d125, float & d126, float & d127, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %135, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k16.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, 
%44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - "{%128, %129, %130, %131}," - " %132," - " %133, %134," - " p, %136, %137;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), 
- "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), - "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), - "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x8x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x8x64_S32S8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %8, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.s8 " - "{%0, %1, %2, %3}," - " %4," - " %5," - " %6, %7," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x8x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x8x64_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = 
uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %8, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3}," - " %4," - " %5," - " %6, %7," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x16x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x16x64_S32S8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %12, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - " %8," - " %9," - " %10, %11," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "l"(desc_a), - "l"(desc_b), - "r"(e), 
"n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x16x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x16x64_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %12, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - " %8," - " %9," - " %10, %11," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x32x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x32x64_S32S8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, 
- uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %20, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - " %16," - " %17," - " %18, %19," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x32x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x32x64_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %20, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - " %16," - " %17," - " %18, %19," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x48x64_S32S8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %28, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, 
%9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " %26, %27," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x48x64_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %28, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, 
%22, %23}," - " %24," - " %25," - " %26, %27," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x64x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x64x64_S32S8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %36, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, 
%22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - " %32," - " %33," - " %34, %35," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x64x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x64x64_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %36, 0;\n" - 
"wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - " %32," - " %33," - " %34, %35," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x80x64_S32S8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - 
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %44, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " %42, %43," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x80x64_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - 
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %44, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " %42, %43," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x96x64 TN S32+=S8*S8 
-template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x96x64_S32S8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %52, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " %50, %51," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - 
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x96x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x96x64_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 
p, %52, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " %50, %51," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x112x64_S32S8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, 
uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %60, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " %58, %59," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "l"(desc_a), - "l"(desc_b), - "r"(e), 
"n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x112x64_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %60, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " 
- " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " %58, %59," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x128x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x128x64_S32S8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & 
d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %68, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - " %64," - " %65," - " %66, %67," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), 
"+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x128x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x128x64_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & 
d62, uint32_t & d63, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %68, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - " %64," - " %65," - " %66, %67," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x144x64_S32S8S8_SS_TN -{ - using 
DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[72]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %76, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - 
" %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - " %72," - " %73," - " %74, %75," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x144x64_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[72]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t 
& d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %76, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - " %72," - " %73," - " %74, %75," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), 
"+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x160x64_S32S8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[80]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & 
d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %84, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - " %80," - " %81," - " %82, %83," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), 
"+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x160x64_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[80]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & 
d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %84, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - " %80," - " %81," - " %82, %83," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), 
"+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x176x64_S32S8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[88]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, 
uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %92, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - " %88," - " %89," - " %90, %91," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), 
"+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x176x64_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[88]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & 
d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %92, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - " %88," - " %89," - " %90, %91," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - 
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x192x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x192x64_S32S8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[96]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t 
& d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %100, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - " %96," - " %97," - " %98, %99," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), 
"+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), - "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), - "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x192x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x192x64_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[96]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, 
uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %100, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - " %96," - " %97," - " %98, %99," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), 
"+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), - "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), - "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x208x64_S32S8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[104]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t 
& d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %108, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " 
- " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - " %104," - " %105," - " %106, %107," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x208x64_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[104]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & 
d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %108, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - " %104," - " %105," - " %106, %107," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), 
"+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x224x64_S32S8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[112]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & 
d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %116, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, 
%46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - " %112," - " %113," - " %114, %115," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x224x64_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[112]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & 
d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %116, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - " %112," - " %113," - " %114, %115," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), 
"+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x240x64_S32S8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[120]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & 
d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t const& e, - GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %124, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - " %120," - " %121," - " %122, %123," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), 
"+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x240x64_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[120]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, 
uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %124, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, 
" - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - " %120," - " %121," - " %122, %123," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x256x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x256x64_S32S8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[128]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, 
- uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %132, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - " %128," - " %129," - " %130, %131," - " p;\n" - "}\n" - : "+r"(d000), 
"+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), - "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), - "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x256x64 TN S32+=S8*S8 -template < - 
GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x256x64_S32S8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[128]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, 
uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %132, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - " %128," - " %129," - " %130, %131," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), 
"+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), - "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), - "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x8x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x8x64_S32S8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, 
uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %11, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.s8 " - "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," - " %8," - " %9, %10," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x8x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x8x64_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %11, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," - " %8," - " %9, %10," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x16x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x16x64_S32S8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %15, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "{%8, %9, %10, %11}," - " %12," - " %13, %14," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x16x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x16x64_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, 
uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %15, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "{%8, %9, %10, %11}," - " %12," - " %13, %14," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x32x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x32x64_S32S8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %23, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.s8 " - 
"{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - "{%16, %17, %18, %19}," - " %20," - " %21, %22," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x32x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x32x64_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %23, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - "{%16, %17, %18, %19}," - " %20," - " %21, %22," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - 
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x48x64_S32S8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %31, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " %29, %30," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), 
"+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x48x64_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %31, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " %29, %30," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), 
"+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x64x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x64x64_S32S8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %39, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - "{%32, %33, %34, %35}," - " %36," - " %37, %38," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), 
"+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x64x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x64x64_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %39, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, 
%5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - "{%32, %33, %34, %35}," - " %36," - " %37, %38," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x80x64_S32S8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, 
uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %47, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " %45, %46," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x80x64_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& 
a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %47, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " %45, %46," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - 
} -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x96x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x96x64_S32S8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %55, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " %53, %54," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), 
- "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x96x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x96x64_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t 
& d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %55, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " %53, %54," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x112x64_S32S8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; 
- using CRegisters = uint32_t[56]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %63, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " %61, %62," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), 
"+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x112x64_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & 
d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %63, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " %61, %62," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x128x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero -> -struct GMMA_64x128x64_S32S8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %71, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - 
" %56, %57, %58, %59, %60, %61, %62, %63}," - "{%64, %65, %66, %67}," - " %68," - " %69, %70," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x128x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x128x64_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, 
uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %71, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - "{%64, %65, %66, %67}," - " %68," - " %69, %70," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), 
"+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x144x64_S32S8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[72]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, 
uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %79, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - "{%72, %73, %74, %75}," - " %76," - " %77, %78," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to 
use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x144x64_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[72]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t const& e, - GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %79, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - "{%72, %73, %74, %75}," - " %76," - " %77, %78," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x64 TN S32+=S8*S8 
-template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x160x64_S32S8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[80]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %87, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.s8 " - "{%0, %1, 
%2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - "{%80, %81, %82, %83}," - " %84," - " %85, %86," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct 
GMMA_64x160x64_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[80]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %87, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, 
%12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - "{%80, %81, %82, %83}," - " %84," - " %85, %86," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x176x64_S32S8S8_RS_TN -{ - using DRegisters = void; 
- using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[88]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %95, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.s8 " - "{%0, 
%1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - "{%88, %89, %90, %91}," - " %92," - " %93, %94," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x176x64_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[88]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t const& e, - 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %95, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - "{%88, %89, %90, %91}," - " %92," - " %93, %94," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x192x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x192x64_S32S8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[96]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & 
d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %103, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - "{%96, %97, %98, %99}," - " %100," - " %101, %102," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), 
"+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), - "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), - "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x192x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x192x64_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[96]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & 
d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %103, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - "{%96, %97, %98, %99}," - " %100," - " %101, %102," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), 
"+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), - "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), - "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x208x64_S32S8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[104]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & 
d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %111, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - 
" %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - "{%104, %105, %106, %107}," - " %108," - " %109, %110," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// 
- -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x208x64_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[104]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - 
uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %111, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - "{%104, %105, %106, %107}," - " %108," - " %109, %110," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), 
"+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x224x64_S32S8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[112]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t 
& d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %119, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, 
%51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - "{%112, %113, %114, %115}," - " %116," - " %117, %118," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x224x64_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[112]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, 
uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %119, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - "{%112, %113, %114, %115}," - " %116," - " %117, %118," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - 
"+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x240x64_S32S8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[120]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t 
const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & 
d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %127, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - "{%120, %121, %122, %123}," - " %124," - " %125, %126," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - 
"+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x240x64_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[120]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, 
uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %127, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, 
%4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - "{%120, %121, %122, %123}," - " %124," - " %125, %126," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), 
"+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x256x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x256x64_S32S8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[128]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, 
uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %135, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, 
%68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - "{%128, %129, %130, %131}," - " %132," - " %133, %134," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), - "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), 
- "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x256x64 TN S32+=S8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x256x64_S32S8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[128]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & 
d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %135, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - 
" %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - "{%128, %129, %130, %131}," - " %132," - " %133, %134," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), - "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), - "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x8x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x8x64_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %8, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.u8 " - "{%0, %1, %2, %3}," - " %4," - " %5," - " %6, %7," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x8x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x8x64_S32S8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %8, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3}," - " %4," - " %5," - " %6, %7," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x16x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x16x64_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %12, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - " %8," - " %9," - " %10, %11," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x16x64 TN S32+=S8*U8 -template < - 
GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x16x64_S32S8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %12, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - " %8," - " %9," - " %10, %11," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x32x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x32x64_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t const& e, - GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %20, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - " %16," - " %17," - " %18, %19," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x32x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x32x64_S32S8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %20, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - " %16," - " %17," - " %18, %19," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), 
"+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x48x64_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %28, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " %26, %27," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - 
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x48x64_S32S8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %28, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " %26, %27," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "l"(desc_a), - 
"l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x64x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x64x64_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %36, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - " %32," - " %33," - " %34, %35," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), 
"+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x64x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x64x64_S32S8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %36, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - " %32," - " %33," - " %34, %35," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - 
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x80x64_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %44, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.u8 " - "{%0, 
%1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " %42, %43," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x80x64_S32S8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & 
d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %44, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " %42, %43," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x96x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x96x64_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, 
uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %52, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " %50, %51," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x96x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x96x64_S32S8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %52, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " 
%50, %51," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x112x64_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & 
d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %60, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " %58, %59," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// 
SPARSE GMMA 64x112x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x112x64_S32S8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %60, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " %58, %59," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), 
- "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x128x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x128x64_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t 
& d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %68, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - " %64," - " %65," - " %66, %67," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x128x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x128x64_S32S8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %68, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, 
%13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - " %64," - " %65," - " %66, %67," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x144x64_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[72]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, 
uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %76, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - " %72," - " %73," - " %74, %75," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), 
"+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x144x64_S32S8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[72]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, 
uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %76, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - " %72," - " %73," - " %74, %75," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), 
"+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x160x64_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[80]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & 
d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %84, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - " %80," - " %81," - " %82, %83," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), 
"+r"(d78), "+r"(d79) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x160x64_S32S8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[80]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, 
uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %84, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - " %80," - " %81," - " %82, %83," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x176x64_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[88]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - 
uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %92, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - " %88," - " %89," - " %90, %91," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) - : "l"(desc_a), - "l"(desc_b), - "r"(e), 
"n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x176x64_S32S8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[88]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, 
uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %92, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - " %88," - " %89," - " %90, %91," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), 
"+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x192x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x192x64_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[96]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, 
uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %100, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - " %96," - " %97," - " %98, %99," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - 
"+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), - "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), - "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x192x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x192x64_S32S8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[96]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, 
uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %100, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - " %96," - " %97," - " %98, %99," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), 
"+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), - "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), - "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x208x64_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[104]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, 
uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %108, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, 
%77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - " %104," - " %105," - " %106, %107," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x64 TN S32+=S8*U8 -template < - GMMA::SparseSel 
spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x208x64_S32S8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[104]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, 
uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %108, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - " %104," - " %105," - " %106, %107," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), 
"+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x224x64_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[112]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, 
uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %116, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " 
- " %104, %105, %106, %107, %108, %109, %110, %111}," - " %112," - " %113," - " %114, %115," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x64 TN S32+=S8*U8 -template < - GMMA::SparseSel 
spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x224x64_S32S8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[112]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, 
uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %116, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - " %112," - " %113," - " %114, %115," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), 
"+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x240x64_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[120]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & 
d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %124, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, 
%27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - " %120," - " %121," - " %122, %123," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), 
"+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x240x64_S32S8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[120]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, 
uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %124, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - " 
%120," - " %121," - " %122, %123," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x256x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel 
= GMMA::SparseSel::Zero -> -struct GMMA_64x256x64_S32S8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[128]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & 
d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %132, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - " %128," - " %129," - " %130, %131," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), 
"+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), - "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), - "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x256x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x256x64_S32S8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[128]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - 
uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, 
uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %132, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - " %128," - " %129," - " %130, %131," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), 
"+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), - "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), - "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x8x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x8x64_S32S8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %11, 0;\n" - 
"wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.u8 " - "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," - " %8," - " %9, %10," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x8x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x8x64_S32S8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %11, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," - " %8," - " %9, %10," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x16x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x16x64_S32S8U8_RS_TN -{ - using 
DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %15, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "{%8, %9, %10, %11}," - " %12," - " %13, %14," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x16x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x16x64_S32S8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - 
"setp.ne.b32 p, %15, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "{%8, %9, %10, %11}," - " %12," - " %13, %14," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x32x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x32x64_S32S8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %23, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - "{%16, %17, %18, %19}," - " %20," - " %21, %22," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "r"(a00), "r"(a01), 
"r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x32x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x32x64_S32S8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %23, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - "{%16, %17, %18, %19}," - " %20," - " %21, %22," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x48x64_S32S8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %31, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " %29, %30," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } 
-}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x48x64_S32S8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %31, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " %29, %30," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_RS_TN_SATURATE without 
CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x64x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x64x64_S32S8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %39, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - "{%32, %33, %34, %35}," - " %36," - " %37, %38," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), 
"+r"(d31) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x64x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x64x64_S32S8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %39, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - "{%32, %33, %34, %35}," - " %36," - " %37, %38," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), 
"+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x80x64_S32S8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, 
%47, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " %45, %46," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x80x64_S32S8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & 
d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %47, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " %45, %46," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x96x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x96x64_S32S8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using 
ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %55, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " %53, %54," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - 
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x96x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x96x64_S32S8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, 
%55, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " %53, %54," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x112x64_S32S8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - 
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %63, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " %61, %62," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), 
"+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x112x64_S32S8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %63, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " %61, %62," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x128x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x128x64_S32S8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t 
const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %71, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - "{%64, %65, %66, %67}," - " %68," - " %69, %70," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), 
"+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x128x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x128x64_S32S8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, 
uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %71, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - "{%64, %65, %66, %67}," - " %68," - " %69, %70," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x144x64_S32S8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[72]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t const& e, - GMMA::ScaleOut 
const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %79, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - "{%72, %73, %74, %75}," - " %76," - " %77, %78," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x64 TN S32+=S8*U8 
-template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x144x64_S32S8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[72]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %79, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, 
%23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - "{%72, %73, %74, %75}," - " %76," - " %77, %78," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x160x64_S32S8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[80]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, 
uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %87, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, 
%60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - "{%80, %81, %82, %83}," - " %84," - " %85, %86," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x160x64_S32S8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[80]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t 
const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %87, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, 
" - " %72, %73, %74, %75, %76, %77, %78, %79}," - "{%80, %81, %82, %83}," - " %84," - " %85, %86," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x176x64_S32S8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[88]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, 
uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %95, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, 
%62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - "{%88, %89, %90, %91}," - " %92," - " %93, %94," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x176x64_S32S8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[88]; - - 
CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %95, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, 
%25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - "{%88, %89, %90, %91}," - " %92," - " %93, %94," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x192x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x192x64_S32S8U8_RS_TN -{ - using DRegisters 
= void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[96]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) 
- asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %103, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - "{%96, %97, %98, %99}," - " %100," - " %101, %102," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), - "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), - "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x192x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x192x64_S32S8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[96]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, 
uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %103, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - "{%96, %97, %98, %99}," - " %100," - " %101, %102," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), 
"+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), - "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), - "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x208x64_S32S8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[104]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & 
d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %111, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - "{%104, %105, %106, %107}," - " %108," - " %109, %110," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - 
"+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x208x64_S32S8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = 
uint32_t[104]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t 
const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %111, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - "{%104, %105, %106, %107}," - " %108," - " %109, %110," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), 
"+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x224x64_S32S8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[112]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, 
uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %119, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," 
- "{%112, %113, %114, %115}," - " %116," - " %117, %118," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero -> -struct GMMA_64x224x64_S32S8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[112]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & 
d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %119, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - "{%112, %113, %114, %115}," - " %116," - " %117, %118," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - 
"+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x240x64_S32S8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[120]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, 
uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %127, 0;\n" - 
"wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - "{%120, %121, %122, %123}," - " %124," - " %125, %126," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), 
"+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x240x64_S32S8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[120]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & 
d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %127, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, 
%68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - "{%120, %121, %122, %123}," - " %124," - " %125, %126," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - 
"r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x256x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x256x64_S32S8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[128]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, 
uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %135, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, 
%124, %125, %126, %127}," - "{%128, %129, %130, %131}," - " %132," - " %133, %134," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), - "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), - "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x256x64 TN S32+=S8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x256x64_S32S8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[128]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, 
uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %135, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - "{%128, %129, %130, %131}," - " %132," - " %133, %134," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), 
"+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), - "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), - "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x8x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero -> -struct GMMA_64x8x64_S32U8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %8, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.s8 " - "{%0, %1, %2, %3}," - " %4," - " %5," - " %6, %7," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x8x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x8x64_S32U8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %8, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.s8.satfinite " - "{%0, %1, %2, %3}," - " %4," - " %5," - " %6, %7," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - 
"r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x16x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x16x64_S32U8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %12, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - " %8," - " %9," - " %10, %11," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x16x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x16x64_S32U8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - 
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %12, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - " %8," - " %9," - " %10, %11," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x32x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x32x64_S32U8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %20, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - " %16," - " %17," - " %18, %19," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), 
"+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x32x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x32x64_S32U8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %20, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - " %16," - " %17," - " %18, %19," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } 
-}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x48x64_S32U8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %28, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " %26, %27," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x48x64_S32U8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %28, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " %26, %27," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x64x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x64x64_S32U8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %36, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - " %32," - " %33," - " %34, %35," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x64x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x64x64_S32U8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %36, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - " %32," - " %33," - " %34, %35," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - 
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x80x64_S32U8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %44, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " 
%41," - " %42, %43," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x80x64_S32U8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t const& e, - 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %44, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " %42, %43," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x96x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x96x64_S32U8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, 
- uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %52, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " %50, %51," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x96x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x96x64_S32U8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %52, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " %50, %51," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), 
"+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x112x64_S32U8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & 
d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %60, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " %58, %59," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x112x64_S32U8S8_SS_TN_SATURATE -{ - 
using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %60, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " %58, %59," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - 
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x128x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x128x64_S32U8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, 
uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %68, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - " %64," - " %65," - " %66, %67," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x128x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x128x64_S32U8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %68, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, 
" - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - " %64," - " %65," - " %66, %67," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x144x64_S32U8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[72]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & 
d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %76, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - " %72," - " %73," - " %74, %75," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), 
"+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x144x64_S32U8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[72]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - 
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %76, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - " %72," - " %73," - " %74, %75," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), 
"+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x160x64_S32U8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[80]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & 
d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %84, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - " %80," - " %81," - " %82, %83," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x160x64_S32U8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[80]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, 
uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %84, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - " %80," - " %81," - " %82, %83," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x176x64_S32U8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[88]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & 
d87, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %92, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - " %88," - " %89," - " %90, %91," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x176x64_S32U8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[88]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - 
uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %92, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - " %88," - " %89," - " %90, %91," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) - : "l"(desc_a), - "l"(desc_b), - "r"(e), 
"n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x192x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x192x64_S32U8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[96]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, 
uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %100, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - " %96," - " %97," - " %98, %99," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - 
"+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), - "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), - "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x192x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x192x64_S32U8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[96]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & 
d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %100, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - " %96," - " %97," - " %98, %99," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), 
"+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), - "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), - "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x208x64_S32U8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[104]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t 
& d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %108, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, 
%94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - " %104," - " %105," - " %106, %107," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x208x64_S32U8S8_SS_TN_SATURATE -{ - using 
DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[104]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - 
uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %108, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - " %104," - " %105," - " %106, %107," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), 
"+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x224x64_S32U8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[112]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, 
uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %116, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - " %112," - " %113," - " %114, 
%115," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x224x64_S32U8S8_SS_TN_SATURATE -{ - using 
DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[112]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - 
uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %116, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - " %112," - " %113," - " %114, %115," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), 
"+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x240x64_S32U8S8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[120]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, 
uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %124, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, 
%43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - " %120," - " %121," - " %122, %123," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), 
"+r"(d118), "+r"(d119) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x240x64_S32U8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[120]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, 
uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %124, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - " %120," - " %121," - " %122, %123," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), 
"+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x256x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x256x64_S32U8S8_SS_TN -{ - using 
DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[128]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - 
uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %132, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - " %128," - " %129," - " %130, %131," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), 
"+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), - "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), - "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x256x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x256x64_S32U8S8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[128]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - 
uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, 
uint32_t & d118, uint32_t & d119, - uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %132, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - " %128," - " %129," - " %130, %131," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), 
- "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), - "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), - "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x8x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x8x64_S32U8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %11, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.s8 " - "{%0, %1, %2, %3}," - "{%4, 
%5, %6, %7}," - " %8," - " %9, %10," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x8x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x8x64_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %11, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.s8.satfinite " - "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," - " %8," - " %9, %10," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x16x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x16x64_S32U8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - 
using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %15, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "{%8, %9, %10, %11}," - " %12," - " %13, %14," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x16x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x16x64_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %15, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.s8.satfinite " - 
"{%0, %1, %2, %3, %4, %5, %6, %7}," - "{%8, %9, %10, %11}," - " %12," - " %13, %14," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x32x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x32x64_S32U8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %23, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - "{%16, %17, %18, %19}," - " %20," - " %21, %22," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else 
- CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x32x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x32x64_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %23, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - "{%16, %17, %18, %19}," - " %20," - " %21, %22," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x48x64_S32U8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %31, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " %29, %30," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x48x64_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %31, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " %29, %30," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_RS_TN_SATURATE without 
CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x64x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x64x64_S32U8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %39, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - "{%32, %33, %34, %35}," - " %36," - " %37, %38," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), 
"+r"(d31) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x64x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x64x64_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %39, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - "{%32, %33, %34, %35}," - " %36," - " %37, %38," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), 
"+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x80x64_S32U8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, 
%47, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " %45, %46," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x80x64_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & 
d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %47, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " %45, %46," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x96x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x96x64_S32U8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using 
ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %55, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " %53, %54," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - 
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x96x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x96x64_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, 
%55, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " %53, %54," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x112x64_S32U8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - 
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %63, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " %61, %62," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), 
"+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x112x64_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %63, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " %61, %62," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x128x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x128x64_S32U8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t 
const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %71, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - "{%64, %65, %66, %67}," - " %68," - " %69, %70," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), 
"+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x128x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x128x64_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, 
uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %71, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - "{%64, %65, %66, %67}," - " %68," - " %69, %70," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x144x64_S32U8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[72]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t const& e, - GMMA::ScaleOut 
const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %79, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - "{%72, %73, %74, %75}," - " %76," - " %77, %78," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x64 TN S32+=U8*S8 
-template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x144x64_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[72]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %79, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, 
%23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - "{%72, %73, %74, %75}," - " %76," - " %77, %78," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x160x64_S32U8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[80]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, 
uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %87, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, 
%60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - "{%80, %81, %82, %83}," - " %84," - " %85, %86," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x160x64_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[80]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t 
const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %87, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, 
" - " %72, %73, %74, %75, %76, %77, %78, %79}," - "{%80, %81, %82, %83}," - " %84," - " %85, %86," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x176x64_S32U8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[88]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, 
uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %95, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, 
%62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - "{%88, %89, %90, %91}," - " %92," - " %93, %94," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x176x64_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[88]; - - 
CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %95, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, 
%25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - "{%88, %89, %90, %91}," - " %92," - " %93, %94," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x192x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x192x64_S32U8S8_RS_TN -{ - using DRegisters 
= void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[96]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) 
- asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %103, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - "{%96, %97, %98, %99}," - " %100," - " %101, %102," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), - "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), - "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x192x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x192x64_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[96]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, 
uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %103, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - "{%96, %97, %98, %99}," - " %100," - " %101, %102," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), 
"+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), - "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), - "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x208x64_S32U8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[104]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & 
d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %111, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - "{%104, %105, %106, %107}," - " %108," - " %109, %110," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - 
"+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x208x64_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = 
uint32_t[104]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t 
const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %111, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - "{%104, %105, %106, %107}," - " %108," - " %109, %110," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), 
"+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x224x64_S32U8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[112]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, 
uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %119, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," 
- "{%112, %113, %114, %115}," - " %116," - " %117, %118," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero -> -struct GMMA_64x224x64_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[112]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & 
d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %119, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - "{%112, %113, %114, %115}," - " %116," - " %117, %118," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - 
"+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x240x64_S32U8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[120]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, 
uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %127, 0;\n" - 
"wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - "{%120, %121, %122, %123}," - " %124," - " %125, %126," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), 
"+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x240x64_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[120]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & 
d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %127, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, 
%68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - "{%120, %121, %122, %123}," - " %124," - " %125, %126," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - 
"r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x256x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x256x64_S32U8S8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[128]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, 
uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %135, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.s8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, 
%124, %125, %126, %127}," - "{%128, %129, %130, %131}," - " %132," - " %133, %134," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), - "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), - "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x256x64 TN S32+=U8*S8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x256x64_S32U8S8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[128]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, 
uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %135, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.s8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - "{%128, %129, %130, %131}," - " %132," - " %133, %134," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), 
"+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), - "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), - "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x8x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero -> -struct GMMA_64x8x64_S32U8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %8, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.u8 " - "{%0, %1, %2, %3}," - " %4," - " %5," - " %6, %7," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x8x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x8x64_S32U8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %8, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.u8.satfinite " - "{%0, %1, %2, %3}," - " %4," - " %5," - " %6, %7," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - 
"r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x16x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x16x64_S32U8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %12, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - " %8," - " %9," - " %10, %11," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x16x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x16x64_S32U8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - 
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %12, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - " %8," - " %9," - " %10, %11," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x32x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x32x64_S32U8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %20, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - " %16," - " %17," - " %18, %19," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), 
"+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x32x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x32x64_S32U8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %20, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - " %16," - " %17," - " %18, %19," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } 
-}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x48x64_S32U8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %28, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " %26, %27," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x48x64_S32U8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %28, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " %26, %27," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x64x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x64x64_S32U8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %36, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - " %32," - " %33," - " %34, %35," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x64x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x64x64_S32U8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %36, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - " %32," - " %33," - " %34, %35," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - 
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x80x64_S32U8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %44, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " 
%41," - " %42, %43," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x80x64_S32U8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t const& e, - 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %44, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " %42, %43," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x96x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x96x64_S32U8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, 
- uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %52, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " %50, %51," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x96x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x96x64_S32U8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %52, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " %50, %51," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), 
"+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x112x64_S32U8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & 
d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %60, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " %58, %59," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x112x64_S32U8U8_SS_TN_SATURATE -{ - 
using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %60, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " %58, %59," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - 
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x128x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x128x64_S32U8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, 
uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %68, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - " %64," - " %65," - " %66, %67," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x128x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x128x64_S32U8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %68, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, 
" - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - " %64," - " %65," - " %66, %67," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x144x64_S32U8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[72]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & 
d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %76, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - " %72," - " %73," - " %74, %75," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), 
"+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x144x64_S32U8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[72]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - 
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %76, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - " %72," - " %73," - " %74, %75," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), 
"+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x160x64_S32U8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[80]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & 
d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %84, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - " %80," - " %81," - " %82, %83," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x160x64_S32U8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[80]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, 
uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %84, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - " %80," - " %81," - " %82, %83," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x176x64_S32U8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[88]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & 
d87, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %92, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - " %88," - " %89," - " %90, %91," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x176x64_S32U8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[88]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - 
uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %92, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - " %88," - " %89," - " %90, %91," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) - : "l"(desc_a), - "l"(desc_b), - "r"(e), 
"n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x192x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x192x64_S32U8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[96]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, 
uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %100, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - " %96," - " %97," - " %98, %99," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - 
"+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), - "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), - "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x192x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x192x64_S32U8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[96]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & 
d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %100, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - " %96," - " %97," - " %98, %99," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), 
"+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), - "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), - "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x208x64_S32U8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[104]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t 
& d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %108, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, 
%94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - " %104," - " %105," - " %106, %107," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x208x64_S32U8U8_SS_TN_SATURATE -{ - using 
DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[104]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - 
uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %108, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - " %104," - " %105," - " %106, %107," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), 
"+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x224x64_S32U8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[112]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, 
uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %116, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - " %112," - " %113," - " %114, 
%115," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x224x64_S32U8U8_SS_TN_SATURATE -{ - using 
DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[112]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - 
uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %116, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - " %112," - " %113," - " %114, %115," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), 
"+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x240x64_S32U8U8_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[120]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, 
uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %124, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, 
%43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - " %120," - " %121," - " %122, %123," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), 
"+r"(d118), "+r"(d119) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x240x64_S32U8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[120]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, 
uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %124, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - " %120," - " %121," - " %122, %123," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), 
"+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x256x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x256x64_S32U8U8_SS_TN -{ - using 
DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[128]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - 
uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %132, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - " %128," - " %129," - " %130, %131," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), 
"+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), - "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), - "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x256x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x256x64_S32U8U8_SS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[128]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - 
uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, 
uint32_t & d118, uint32_t & d119, - uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %132, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - " %128," - " %129," - " %130, %131," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), 
- "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), - "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), - "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x8x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x8x64_S32U8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %11, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.u8 " - "{%0, %1, %2, %3}," - "{%4, 
%5, %6, %7}," - " %8," - " %9, %10," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x8x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x8x64_S32U8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %11, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.u8.satfinite " - "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," - " %8," - " %9, %10," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x16x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x16x64_S32U8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - 
using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %15, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "{%8, %9, %10, %11}," - " %12," - " %13, %14," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x16x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x16x64_S32U8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %15, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.u8.satfinite " - 
"{%0, %1, %2, %3, %4, %5, %6, %7}," - "{%8, %9, %10, %11}," - " %12," - " %13, %14," - " p;\n" - "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x32x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x32x64_S32U8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %23, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - "{%16, %17, %18, %19}," - " %20," - " %21, %22," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else 
- CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x32x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x32x64_S32U8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %23, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - "{%16, %17, %18, %19}," - " %20," - " %21, %22," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x48x64_S32U8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %31, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " %29, %30," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x48x64_S32U8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %31, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " %29, %30," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_RS_TN_SATURATE without 
CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x64x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x64x64_S32U8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %39, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - "{%32, %33, %34, %35}," - " %36," - " %37, %38," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), 
"+r"(d31) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x64x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x64x64_S32U8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %39, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - "{%32, %33, %34, %35}," - " %36," - " %37, %38," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), 
"+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x80x64_S32U8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, 
%47, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " %45, %46," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x80x64_S32U8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & 
d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %47, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " %45, %46," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x96x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x96x64_S32U8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using 
ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %55, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " %53, %54," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - 
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x96x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x96x64_S32U8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, 
%55, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " %53, %54," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x112x64_S32U8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - 
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %63, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " %61, %62," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), 
"+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x112x64_S32U8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %63, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " %61, %62," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x128x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x128x64_S32U8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t 
const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %71, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - "{%64, %65, %66, %67}," - " %68," - " %69, %70," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), 
"+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x128x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x128x64_S32U8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, 
uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %71, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - "{%64, %65, %66, %67}," - " %68," - " %69, %70," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x144x64_S32U8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[72]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t const& e, - GMMA::ScaleOut 
const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %79, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - "{%72, %73, %74, %75}," - " %76," - " %77, %78," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x64 TN S32+=U8*U8 
-template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x144x64_S32U8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[72]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %79, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, 
%23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - "{%72, %73, %74, %75}," - " %76," - " %77, %78," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x160x64_S32U8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[80]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, 
uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %87, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, 
%60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - "{%80, %81, %82, %83}," - " %84," - " %85, %86," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x160x64_S32U8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[80]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t 
const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %87, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, 
" - " %72, %73, %74, %75, %76, %77, %78, %79}," - "{%80, %81, %82, %83}," - " %84," - " %85, %86," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x176x64_S32U8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[88]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, 
uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %95, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, 
%62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - "{%88, %89, %90, %91}," - " %92," - " %93, %94," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x176x64_S32U8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[88]; - - 
CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %95, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, 
%25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - "{%88, %89, %90, %91}," - " %92," - " %93, %94," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x192x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x192x64_S32U8U8_RS_TN -{ - using DRegisters 
= void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[96]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) 
- asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %103, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - "{%96, %97, %98, %99}," - " %100," - " %101, %102," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), - "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), - "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x192x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x192x64_S32U8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[96]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, - uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, - uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, - uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, - uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, - uint32_t & d76, 
uint32_t & d77, uint32_t & d78, uint32_t & d79, - uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, - uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, - uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, - uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %103, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - "{%96, %97, %98, %99}," - " %100," - " %101, %102," - " p;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), - "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), - "+r"(d68), "+r"(d69), 
"+r"(d70), "+r"(d71), - "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), - "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), - "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), - "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), - "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), - "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x208x64_S32U8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[104]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & 
d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %111, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - "{%104, %105, %106, %107}," - " %108," - " %109, %110," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - 
"+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x208x64_S32U8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = 
uint32_t[104]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t 
const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %111, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - "{%104, %105, %106, %107}," - " %108," - " %109, %110," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), 
"+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x224x64_S32U8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[112]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, 
uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %119, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," 
- "{%112, %113, %114, %115}," - " %116," - " %117, %118," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero -> -struct GMMA_64x224x64_S32U8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[112]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & 
d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %119, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - "{%112, %113, %114, %115}," - " %116," - " %117, %118," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - 
"+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x240x64_S32U8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[120]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, 
uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %127, 0;\n" - 
"wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - "{%120, %121, %122, %123}," - " %124," - " %125, %126," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), 
"+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x240x64_S32U8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[120]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & 
d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %127, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, 
%68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - "{%120, %121, %122, %123}," - " %124," - " %125, %126," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - 
"r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x256x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x256x64_S32U8U8_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[128]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, 
uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %135, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.u8 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, 
%124, %125, %126, %127}," - "{%128, %129, %130, %131}," - " %132," - " %133, %134," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), - "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), - "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x256x64 TN S32+=U8*U8 -template < - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x256x64_S32U8U8_RS_TN_SATURATE -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[128]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, - uint64_t const& desc_b, - uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, - uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, - uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, - uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, - uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, - uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, - uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, - uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, - uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, - uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, - uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, - uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, - uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, - uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, - uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, - uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, - uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, - uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, - uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, - uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, - uint32_t & d080, 
uint32_t & d081, uint32_t & d082, uint32_t & d083, - uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, - uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, - uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, - uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, - uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, - uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, - uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, - uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, - uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, - uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, - uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %135, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.u8.satfinite " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - "{%128, %129, %130, %131}," - " %132," - " %133, %134," - " p;\n" - "}\n" - : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), - "+r"(d004), "+r"(d005), 
"+r"(d006), "+r"(d007), - "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), - "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), - "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), - "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), - "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), - "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), - "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), - "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), - "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), - "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), - "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), - "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), - "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), - "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), - "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), - "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), - "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), - "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), - "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), - "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), - "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), - "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), - "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), - "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), - "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), - "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), - "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), - "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), - "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), - "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x8x64 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x8x64_F16E4M3E4M3_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[2]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %6, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e4m3.e4m3 " - "{%0, %1}," - " %2," - " %3," - " %4, %5," - " p, %7, %8;\n" - "}\n" - : "+r"(d0), "+r"(d1) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x8x64 TN F16+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x8x64_F16E4M3E4M3_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[2]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %9, 0;\n" - 
"wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e4m3.e4m3 " - "{%0, %1}," - "{%2, %3, %4, %5}," - " %6," - " %7, %8," - " p, %10, %11;\n" - "}\n" - : "+r"(d0), "+r"(d1) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x8x64 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x8x64_F32E4M3E4M3_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[4]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %8, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e4m3.e4m3 " - "{%0, %1, %2, %3}," - " %4," - " %5," - " %6, %7," - " p, %9, %10;\n" - "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x8x64 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, 
- GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x8x64_F32E4M3E4M3_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[4]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %11, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e4m3.e4m3 " - "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," - " %8," - " %9, %10," - " p, %12, %13;\n" - "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x16x64 TN F16+=E4M3*E4M3 +// SPARSE GMMA 64x256x16 TN F32+=TF32*TF32 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x16x64_F16E4M3E4M3_SS_TN +struct GMMA_64x256x16_F32TF32TF32_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters 
= uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; + using CRegisters = float[128]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + 
float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %8, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e4m3.e4m3 " - "{%0, %1, %2, %3}," - " %4," - " %5," - " %6, %7," - " p, %9, %10;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134;\n" "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), 
"+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x16x64 TN F16+=E4M3*E4M3 +// SPARSE GMMA 64x256x16 TN F32+=TF32*TF32 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x16x64_F16E4M3E4M3_RS_TN +struct GMMA_64x256x16_F32TF32TF32_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using 
BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; + using CRegisters = float[128]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, 
float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %11, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e4m3.e4m3 " - "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," - " %8," - " %9, %10," - " p, %12, %13;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137;\n" "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), 
"+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x16x64 TN F32+=E4M3*E4M3 +// SPARSE GMMA 64x8x64 TN S32+=S8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, 
GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x16x64_F32E4M3E4M3_SS_TN +struct GMMA_64x8x64_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[8]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, - float & d4, float & d5, float & d6, float & d7, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %12, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - " %8," - " %9," - " %10, %11," - " p, %13, %14;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.s8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), - "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x16x64 TN F32+=E4M3*E4M3 +// SPARSE GMMA 64x8x64 TN S32+=S8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero > -struct GMMA_64x16x64_F32E4M3E4M3_RS_TN +struct GMMA_64x8x64_S32S8S8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[8]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, - float & d4, float & d5, float & d6, float & d7, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %15, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "{%8, %9, %10, %11}," - " %12," - " %13, %14," - " p, %16, %17;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), - "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x32x64 TN 
F16+=E4M3*E4M3 +// SPARSE GMMA 64x16x64 TN S32+=S8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x32x64_F16E4M3E4M3_SS_TN +struct GMMA_64x16x64_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -33950,47 +4877,46 @@ struct GMMA_64x32x64_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %12, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e4m3.e4m3 " + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7}," " %8," " %9," " %10, %11," - " p, %13, %14;\n" + " p;\n" "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x32x64 TN F16+=E4M3*E4M3 +// SPARSE GMMA 64x16x64 TN S32+=S8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x32x64_F16E4M3E4M3_RS_TN +struct GMMA_64x16x64_S32S8S8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; using CRegisters = uint32_t[8]; 
CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, @@ -33998,151 +4924,95 @@ struct GMMA_64x32x64_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %15, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e4m3.e4m3 " + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7}," - "{%8, %9, %10, %11}," - " %12," - " %13, %14," - " p, %16, %17;\n" + " %8," + " %9," + " %10, %11," + " p;\n" "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x32x64 TN F32+=E4M3*E4M3 +// SPARSE GMMA 64x32x64 TN S32+=S8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x32x64_F32E4M3E4M3_SS_TN +struct GMMA_64x32x64_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = 
uint64_t[1]; - using CRegisters = float[16]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %20, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e4m3.e4m3 " + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," " %16," " %17," " %18, %19," - " p, %21, %22;\n" + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x32x64 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x32x64_F32E4M3E4M3_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[16]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %23, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - "{%16, %17, %18, %19}," - " %20," - " %21, %22," - " p, %24, %25;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x64 TN F16+=E4M3*E4M3 +// SPARSE GMMA 64x32x64 TN S32+=S8*S8 
template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x48x64_F16E4M3E4M3_SS_TN +struct GMMA_64x32x64_S32S8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[12]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -34150,224 +5020,175 @@ struct GMMA_64x48x64_F16E4M3E4M3_SS_TN uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %16, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e4m3.e4m3 " + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11}," - " %12," - " %13," - " %14, %15," - " p, %17, %18;\n" + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x64 TN F16+=E4M3*E4M3 +// SPARSE GMMA 64x64x64 TN S32+=S8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x48x64_F16E4M3E4M3_RS_TN +struct GMMA_64x64x64_S32S8S8_SS_TN { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[12]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %19, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11}," - "{%12, %13, %14, %15}," - " %16," - " %17, %18," - " p, %20, %21;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), 
"+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x64 TN F32+=E4M3*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x48x64_F32E4M3E4M3_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[24]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %28, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e4m3.e4m3 " + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " %26, %27," - " p, %29, %30;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), 
"+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x64 TN F32+=E4M3*E4M3 +// SPARSE GMMA 64x64x64 TN S32+=S8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x48x64_F32E4M3E4M3_RS_TN +struct GMMA_64x64x64_S32S8S8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[24]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + 
fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %31, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e4m3.e4m3 " + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " %29, %30," - " p, %32, %33;\n" + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d00), "+r"(d01), 
"+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x64x64 TN F16+=E4M3*E4M3 +// SPARSE GMMA 64x96x64 TN S32+=S8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x64x64_F16E4M3E4M3_SS_TN +struct GMMA_64x96x64_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -34376,231 +5197,306 @@ struct GMMA_64x64x64_F16E4M3E4M3_SS_TN uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & 
d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %20, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e4m3.e4m3 " + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - " %16," - " %17," - " %18, %19," - " p, %21, %22;\n" + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to 
use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x64x64 TN F16+=E4M3*E4M3 +// SPARSE GMMA 64x96x64 TN S32+=S8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x64x64_F16E4M3E4M3_RS_TN +struct GMMA_64x96x64_S32S8S8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %23, 0;\n" - 
"wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e4m3.e4m3 " + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - "{%16, %17, %18, %19}," - " %20," - " %21, %22," - " p, %24, %25;\n" + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x64x64 TN F32+=E4M3*E4M3 +// SPARSE GMMA 64x128x64 TN S32+=S8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero > -struct GMMA_64x64x64_F32E4M3E4M3_SS_TN +struct GMMA_64x128x64_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[32]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %36, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - " %32," - " %33," - " %34, %35," - " p, %37, %38;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), 
"+r"(d61), "+r"(d62), "+r"(d63) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x64x64 TN F32+=E4M3*E4M3 +// SPARSE GMMA 64x128x64 TN S32+=S8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x64x64_F32E4M3E4M3_RS_TN +struct GMMA_64x128x64_S32S8S8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[32]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + 
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %39, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - "{%32, %33, %34, %35}," - " %36," - " %37, %38," - " p, %40, %41;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - 
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x64 TN F16+=E4M3*E4M3 +// SPARSE GMMA 64x192x64 TN S32+=S8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x80x64_F16E4M3E4M3_SS_TN +struct GMMA_64x192x64_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using 
ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[20]; + using CRegisters = uint32_t[96]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -34610,578 +5506,681 @@ struct GMMA_64x80x64_F16E4M3E4M3_SS_TN uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %24, 0;\n" - 
"wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19}," - " %20," - " %21," - " %22, %23," - " p, %25, %26;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), 
+ "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x64 TN F16+=E4M3*E4M3 +// SPARSE GMMA 64x192x64 TN S32+=S8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x80x64_F16E4M3E4M3_RS_TN +struct GMMA_64x192x64_S32S8S8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[20]; + using CRegisters = uint32_t[96]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & 
d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %27, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19}," - "{%20, %21, %22, %23}," - " %24," - " %25, %26," - " p, %28, %29;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, 
%81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x64 TN F32+=E4M3*E4M3 +// SPARSE GMMA 
64x256x64 TN S32+=S8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x80x64_F32E4M3E4M3_SS_TN +struct GMMA_64x256x64_S32S8S8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[40]; + using CRegisters = uint32_t[128]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, 
uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %44, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " %42, %43," - " p, %45, %46;\n" + "setp.ne.b32 p, %132, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), 
"+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x64 TN F32+=E4M3*E4M3 +// SPARSE GMMA 64x256x64 TN S32+=S8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x80x64_F32E4M3E4M3_RS_TN +struct GMMA_64x256x64_S32S8S8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using 
BRegisters = uint64_t[1]; - using CRegisters = float[40]; + using CRegisters = uint32_t[128]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + 
uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %47, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " %45, %46," - " p, %48, %49;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, 
%44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), 
"+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x96x64 TN F16+=E4M3*E4M3 +// SPARSE GMMA 64x8x64 TN S32+=S8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x96x64_F16E4M3E4M3_SS_TN +struct GMMA_64x8x64_S32S8S8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & 
d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %28, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " %26, %27," - " p, %29, %30;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.s8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "l"(desc_a), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x96x64 TN 
F16+=E4M3*E4M3 +// SPARSE GMMA 64x8x64 TN S32+=S8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x96x64_F16E4M3E4M3_RS_TN +struct GMMA_64x8x64_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %31, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " %29, %30," - " p, %32, %33;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - 
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x96x64 TN F32+=E4M3*E4M3 +// SPARSE GMMA 64x16x64 TN S32+=S8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x96x64_F32E4M3E4M3_SS_TN +struct GMMA_64x16x64_S32S8S8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[48]; + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & 
d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %52, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " %50, %51," - " p, %53, %54;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) - : "l"(desc_a), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x96x64 TN F32+=E4M3*E4M3 +// SPARSE GMMA 64x16x64 TN S32+=S8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x96x64_F32E4M3E4M3_RS_TN +struct GMMA_64x16x64_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[48]; + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %55, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " %53, %54," - " p, %56, %57;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x64 TN F16+=E4M3*E4M3 +// SPARSE GMMA 64x32x64 TN S32+=S8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x112x64_F16E4M3E4M3_SS_TN +struct GMMA_64x32x64_S32S8S8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[28]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %32, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e4m3.e4m3 " + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27}," - " %28," - " %29," - " %30, %31," - " p, %33, %34;\n" + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), 
"+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) - : "l"(desc_a), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x64 TN F16+=E4M3*E4M3 +// SPARSE GMMA 64x32x64 TN S32+=S8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x112x64_F16E4M3E4M3_RS_TN +struct GMMA_64x32x64_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[28]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -35190,223 +6189,177 @@ struct GMMA_64x112x64_F16E4M3E4M3_RS_TN uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, 
uint32_t & d26, uint32_t & d27, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %35, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e4m3.e4m3 " + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27}," - "{%28, %29, %30, %31}," - " %32," - " %33, %34," - " p, %36, %37;\n" + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x64 TN F32+=E4M3*E4M3 +// SPARSE GMMA 64x64x64 TN S32+=S8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero > -struct GMMA_64x112x64_F32E4M3E4M3_SS_TN +struct GMMA_64x64x64_S32S8S8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[56]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" 
- "setp.ne.b32 p, %60, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " %58, %59," - " p, %61, %62;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) - : "l"(desc_a), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); 
#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x64 TN F32+=E4M3*E4M3 +// SPARSE GMMA 64x64x64 TN S32+=S8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x112x64_F32E4M3E4M3_RS_TN +struct GMMA_64x64x64_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[56]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t 
& d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %63, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " %61, %62," - " p, %64, %65;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), 
"+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x128x64 TN F16+=E4M3*E4M3 +// SPARSE GMMA 64x96x64 TN S32+=S8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x128x64_F16E4M3E4M3_SS_TN +struct GMMA_64x96x64_S32S8S8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, @@ -35416,23 +6369,30 @@ struct GMMA_64x128x64_F16E4M3E4M3_SS_TN uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, 
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %36, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - " %32," - " %33," - " %34, %35," - " p, %37, %38;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -35441,32 +6401,34 @@ struct GMMA_64x128x64_F16E4M3E4M3_SS_TN "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) - : "l"(desc_a), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), 
"n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x128x64 TN F16+=E4M3*E4M3 +// SPARSE GMMA 64x96x64 TN S32+=S8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x128x64_F16E4M3E4M3_RS_TN +struct GMMA_64x96x64_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -35479,23 +6441,30 @@ struct GMMA_64x128x64_F16E4M3E4M3_RS_TN uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %39, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, 
%7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - "{%32, %33, %34, %35}," - " %36," - " %37, %38," - " p, %40, %41;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -35504,61 +6473,64 @@ struct GMMA_64x128x64_F16E4M3E4M3_RS_TN "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x128x64 TN F32+=E4M3*E4M3 +// SPARSE GMMA 64x128x64 TN S32+=S8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero > -struct GMMA_64x128x64_F32E4M3E4M3_SS_TN +struct GMMA_64x128x64_S32S8S8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[64]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, 
uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %68, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e4m3.e4m3 " + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -35567,81 +6539,80 @@ struct GMMA_64x128x64_F32E4M3E4M3_SS_TN " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " " %56, %57, %58, %59, %60, %61, %62, %63}," - " %64," - " %65," - " %66, %67," - " p, %69, %70;\n" + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) 
- : "l"(desc_a), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x128x64 TN F32+=E4M3*E4M3 +// SPARSE GMMA 64x128x64 TN S32+=S8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x128x64_F32E4M3E4M3_RS_TN +struct GMMA_64x128x64_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[64]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d00, 
float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %71, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e4m3.e4m3 " + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -35653,53 +6624,50 @@ struct GMMA_64x128x64_F32E4M3E4M3_RS_TN "{%64, %65, %66, %67}," " %68," " %69, %70," - " p, %72, %73;\n" + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + 
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x64 TN F16+=E4M3*E4M3 +// SPARSE GMMA 64x192x64 TN S32+=S8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x144x64_F16E4M3E4M3_SS_TN +struct GMMA_64x192x64_S32S8S8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[36]; + using CRegisters = uint32_t[96]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, @@ -35710,24 +6678,47 @@ struct GMMA_64x144x64_F16E4M3E4M3_SS_TN uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t 
& d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %40, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35}," - " %36," - " %37," - " %38, %39," - " p, %41, %42;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " 
%101, %102," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -35737,34 +6728,45 @@ struct GMMA_64x144x64_F16E4M3E4M3_SS_TN "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) - : "l"(desc_a), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x64 TN F16+=E4M3*E4M3 +// SPARSE GMMA 64x192x64 TN S32+=S8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct 
GMMA_64x144x64_F16E4M3E4M3_RS_TN +struct GMMA_64x192x64_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[36]; + using CRegisters = uint32_t[96]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -35778,24 +6780,47 @@ struct GMMA_64x144x64_F16E4M3E4M3_RS_TN uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %43, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, 
%11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35}," - "{%36, %37, %38, %39}," - " %40," - " %41, %42," - " p, %44, %45;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -35805,65 +6830,91 @@ struct GMMA_64x144x64_F16E4M3E4M3_RS_TN "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) : 
"r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x64 TN F32+=E4M3*E4M3 +// SPARSE GMMA 64x256x64 TN S32+=S8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x144x64_F32E4M3E4M3_SS_TN +struct GMMA_64x256x64_S32S8S8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[72]; + using CRegisters = uint32_t[128]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & 
d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, 
uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %76, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e4m3.e4m3 " + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -35872,88 +6923,120 @@ struct GMMA_64x144x64_F32E4M3E4M3_SS_TN " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - " %72," - " %73," - " %74, %75," - " p, %77, %78;\n" + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), 
"+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) - : "l"(desc_a), + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), 
"+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x64 TN F32+=E4M3*E4M3 +// SPARSE GMMA 64x256x64 TN S32+=S8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x144x64_F32E4M3E4M3_RS_TN +struct GMMA_64x256x64_S32S8S8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[72]; + using CRegisters = uint32_t[128]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, 
float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, 
uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %79, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e4m3.e4m3 " + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -35962,387 +7045,258 @@ struct GMMA_64x144x64_F32E4M3E4M3_RS_TN " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - "{%72, %73, %74, %75}," - " %76," - " %77, %78," - " p, %80, %81;\n" + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, 
%103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), 
"+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x64 TN F16+=E4M3*E4M3 +// SPARSE GMMA 64x8x64 TN S32+=S8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x160x64_F16E4M3E4M3_SS_TN +struct GMMA_64x8x64_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & 
d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %44, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " %42, %43," - " p, %45, %46;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.u8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x64 TN F16+=E4M3*E4M3 +// SPARSE GMMA 64x8x64 TN S32+=S8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x160x64_F16E4M3E4M3_RS_TN +struct GMMA_64x8x64_S32S8U8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, uint32_t const& 
e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %47, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " %45, %46," - " p, %48, %49;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x64 TN F32+=E4M3*E4M3 +// SPARSE GMMA 64x16x64 TN S32+=S8*U8 
template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x160x64_F32E4M3E4M3_SS_TN +struct GMMA_64x16x64_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[80]; + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %84, 0;\n" - 
"wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - " %80," - " %81," - " %82, %83," - " p, %85, %86;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x64 TN F32+=E4M3*E4M3 +// SPARSE GMMA 64x16x64 TN S32+=S8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x160x64_F32E4M3E4M3_RS_TN +struct GMMA_64x16x64_S32S8U8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[80]; + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float 
& d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %87, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - "{%80, %81, %82, %83}," - " %84," - " %85, %86," - " p, %88, %89;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - 
"+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x64 TN F16+=E4M3*E4M3 +// SPARSE GMMA 64x32x64 TN S32+=S8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x176x64_F16E4M3E4M3_SS_TN +struct GMMA_64x32x64_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[44]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -36351,337 +7305,220 @@ struct GMMA_64x176x64_F16E4M3E4M3_SS_TN uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & 
d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %48, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e4m3.e4m3 " + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43}," - " %44," - " %45," - " %46, %47," - " p, %49, %50;\n" + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x64 TN F16+=E4M3*E4M3 +// SPARSE GMMA 64x32x64 TN S32+=S8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x176x64_F16E4M3E4M3_RS_TN +struct GMMA_64x32x64_S32S8U8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[44]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %51, 0;\n" - 
"wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e4m3.e4m3 " + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43}," - "{%44, %45, %46, %47}," - " %48," - " %49, %50," - " p, %52, %53;\n" + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x64 TN F32+=E4M3*E4M3 +// SPARSE GMMA 64x64x64 TN S32+=S8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero > -struct GMMA_64x176x64_F32E4M3E4M3_SS_TN +struct GMMA_64x64x64_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[88]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t 
& d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %92, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - " %88," - " %89," - " %90, %91," - " p, %93, %94;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), 
"+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x64 TN F32+=E4M3*E4M3 +// SPARSE GMMA 64x64x64 TN S32+=S8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x176x64_F32E4M3E4M3_RS_TN +struct GMMA_64x64x64_S32S8U8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[88]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & 
d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %95, 0;\n" - 
"wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - "{%88, %89, %90, %91}," - " %92," - " %93, %94," - " p, %96, %97;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) - : "r"(a00), "r"(a01), "r"(a02), 
"r"(a03), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x192x64 TN F16+=E4M3*E4M3 +// SPARSE GMMA 64x96x64 TN S32+=S8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x192x64_F16E4M3E4M3_SS_TN +struct GMMA_64x96x64_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -36708,11 +7545,12 @@ struct GMMA_64x192x64_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %52, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e4m3.e4m3 " + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -36722,7 +7560,7 @@ struct GMMA_64x192x64_F16E4M3E4M3_SS_TN " %48," " %49," " %50, %51," - " p, %53, %54;\n" + " p;\n" "}\n" : "+r"(d00), 
"+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -36739,31 +7577,29 @@ struct GMMA_64x192x64_F16E4M3E4M3_SS_TN : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x192x64 TN F16+=E4M3*E4M3 +// SPARSE GMMA 64x96x64 TN S32+=S8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x192x64_F16E4M3E4M3_RS_TN +struct GMMA_64x96x64_S32S8U8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, @@ -36781,21 +7617,22 @@ struct GMMA_64x192x64_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %55, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, 
%25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " %53, %54," - " p, %56, %57;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -36809,68 +7646,59 @@ struct GMMA_64x192x64_F16E4M3E4M3_RS_TN "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x192x64 TN F32+=E4M3*E4M3 +// SPARSE GMMA 64x128x64 TN S32+=S8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x192x64_F32E4M3E4M3_SS_TN +struct GMMA_64x128x64_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[96]; + using CRegisters = uint32_t[64]; 
CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, - float & d88, float & d89, float & d90, float & d91, - float & d92, float & d93, float & d94, float & d95, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & 
d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %100, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e4m3.e4m3 " + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -36878,169 +7706,133 @@ struct GMMA_64x192x64_F32E4M3E4M3_SS_TN " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - " %96," - " %97," - " %98, %99," - " p, %101, %102;\n" + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), 
"+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), - "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), - "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x192x64 TN F32+=E4M3*E4M3 +// SPARSE GMMA 
64x128x64 TN S32+=S8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x192x64_F32E4M3E4M3_RS_TN +struct GMMA_64x128x64_S32S8U8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[96]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, - float & d88, float & d89, float & d90, float & d91, - float & d92, float & d93, float & d94, float & d95, + uint32_t & d00, uint32_t & d01, uint32_t & 
d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %103, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - "{%96, %97, %98, %99}," - " %100," - " %101, %102," - " p, %104, %105;\n" + "setp.ne.b32 p, %68, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), - "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), - "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + 
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x64 TN F16+=E4M3*E4M3 +// SPARSE GMMA 64x192x64 TN S32+=S8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x208x64_F16E4M3E4M3_SS_TN +struct GMMA_64x192x64_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[52]; + using CRegisters = uint32_t[96]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -37058,26 +7850,43 @@ struct GMMA_64x208x64_F16E4M3E4M3_SS_TN uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + 
uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %56, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e4m3.e4m3 " + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51}," - " %52," - " %53," - " %54, %55," - " p, %57, %58;\n" + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -37091,37 +7900,44 @@ struct GMMA_64x208x64_F16E4M3E4M3_SS_TN "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), 
"+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x64 TN F16+=E4M3*E4M3 +// SPARSE GMMA 64x192x64 TN S32+=S8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x208x64_F16E4M3E4M3_RS_TN +struct GMMA_64x192x64_S32S8U8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[52]; + using CRegisters = uint32_t[96]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, @@ -37136,26 +7952,43 @@ struct 
GMMA_64x208x64_F16E4M3E4M3_RS_TN uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %59, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e4m3.e4m3 " + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51}," - "{%52, %53, %54, %55}," - " %56," - " %57, %58," - " p, %60, %61;\n" + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " 
%96," + " %97," + " %98, %99," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -37169,73 +8002,87 @@ struct GMMA_64x208x64_F16E4M3E4M3_RS_TN "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x64 TN F32+=E4M3*E4M3 +// SPARSE GMMA 64x256x64 TN S32+=S8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x208x64_F32E4M3E4M3_SS_TN +struct GMMA_64x256x64_S32S8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using 
BRegisters = uint64_t[1]; - using CRegisters = float[104]; + using CRegisters = uint32_t[128]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, 
uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { 
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %108, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e4m3.e4m3 " + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -37248,104 +8095,116 @@ struct GMMA_64x208x64_F32E4M3E4M3_SS_TN " %72, %73, %74, %75, %76, %77, %78, %79, " " %80, %81, %82, %83, %84, %85, %86, %87, " " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - " %104," - " %105," - " %106, %107," - " p, %109, %110;\n" + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), 
"+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x64 TN F32+=E4M3*E4M3 +// SPARSE GMMA 64x256x64 TN S32+=S8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x208x64_F32E4M3E4M3_RS_TN +struct GMMA_64x256x64_S32S8U8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[104]; + using CRegisters = uint32_t[128]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, 
float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t 
& d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %111, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e4m3.e4m3 " + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -37358,458 +8217,361 @@ struct GMMA_64x208x64_F32E4M3E4M3_RS_TN " %72, %73, %74, %75, %76, %77, %78, %79, " " %80, %81, %82, %83, %84, %85, %86, %87, " " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - "{%104, %105, %106, %107}," - " %108," - " %109, %110," - " p, %112, %113;\n" + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " 
%112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), 
"+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& 
e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.u8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x64 TN F16+=E4M3*E4M3 +// SPARSE GMMA 64x16x64 TN S32+=S8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x224x64_F16E4M3E4M3_SS_TN +struct GMMA_64x16x64_S32S8U8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & 
d53, uint32_t & d54, uint32_t & d55, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %60, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " %58, %59," - " p, %61, %62;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "l"(desc_a), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to 
use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x64 TN F16+=E4M3*E4M3 +// SPARSE GMMA 64x16x64 TN S32+=S8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x224x64_F16E4M3E4M3_RS_TN +struct GMMA_64x16x64_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, 
- uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %63, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " %61, %62," - " p, %64, %65;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x64 TN F32+=E4M3*E4M3 +// SPARSE GMMA 64x32x64 TN S32+=S8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x224x64_F32E4M3E4M3_SS_TN +struct GMMA_64x32x64_S32S8U8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[112]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & 
d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %116, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, 
%102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - " %112," - " %113," - " %114, %115," - " p, %117, %118;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) - : "l"(desc_a), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), 
"r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x64 TN F32+=E4M3*E4M3 +// SPARSE GMMA 64x32x64 TN S32+=S8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x224x64_F32E4M3E4M3_RS_TN +struct GMMA_64x32x64_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[112]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - 
float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %119, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, 
%79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - "{%112, %113, %114, %115}," - " %116," - " %117, %118," - " p, %120, %121;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + : "+r"(d00), 
"+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x64 TN F16+=E4M3*E4M3 +// SPARSE GMMA 64x64x64 TN S32+=S8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x240x64_F16E4M3E4M3_SS_TN +struct GMMA_64x64x64_S32S8U8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[60]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, @@ -37819,34 +8581,24 @@ struct GMMA_64x240x64_F16E4M3E4M3_SS_TN uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, 
uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %64, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59}," - " %60," - " %61," - " %62, %63," - " p, %65, %66;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -37855,41 +8607,30 @@ struct GMMA_64x240x64_F16E4M3E4M3_SS_TN "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), 
"+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) - : "l"(desc_a), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x64 TN F16+=E4M3*E4M3 +// SPARSE GMMA 64x64x64 TN S32+=S8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x240x64_F16E4M3E4M3_RS_TN +struct GMMA_64x64x64_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[60]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -37902,34 +8643,24 @@ struct GMMA_64x240x64_F16E4M3E4M3_RS_TN uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, 
uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %67, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e4m3.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59}," - "{%60, %61, %62, %63}," - " %64," - " %65, %66," - " p, %68, %69;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -37938,283 +8669,177 @@ struct GMMA_64x240x64_F16E4M3E4M3_RS_TN "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x64 TN F32+=E4M3*E4M3 +// SPARSE GMMA 64x96x64 TN S32+=S8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x240x64_F32E4M3E4M3_SS_TN +struct GMMA_64x96x64_S32S8U8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[120]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & 
d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %124, 0;\n" 
- "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e4m3.e4m3 " + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - " %120," - " %121," - " %122, %123," - " p, %125, %126;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), 
"+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) - : "l"(desc_a), + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x64 TN F32+=E4M3*E4M3 +// SPARSE GMMA 64x96x64 TN 
S32+=S8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x240x64_F32E4M3E4M3_RS_TN +struct GMMA_64x96x64_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[120]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float 
& d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %127, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e4m3.e4m3 " + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, 
%71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - "{%120, %121, %122, %123}," - " %124," - " %125, %126," - " p, %128, %129;\n" + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) - 
: "r"(a000), "r"(a001), "r"(a002), "r"(a003), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x256x64 TN F16+=E4M3*E4M3 +// SPARSE GMMA 64x128x64 TN S32+=S8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x256x64_F16E4M3E4M3_SS_TN +struct GMMA_64x128x64_S32S8U8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, @@ 
-38236,11 +8861,12 @@ struct GMMA_64x256x64_F16E4M3E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %68, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e4m3.e4m3 " + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -38249,10 +8875,10 @@ struct GMMA_64x256x64_F16E4M3E4M3_SS_TN " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " " %56, %57, %58, %59, %60, %61, %62, %63}," - " %64," - " %65," - " %66, %67," - " p, %69, %70;\n" + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -38270,25 +8896,23 @@ struct GMMA_64x256x64_F16E4M3E4M3_SS_TN "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "l"(desc_a), + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x256x64 TN F16+=E4M3*E4M3 +// SPARSE GMMA 64x128x64 TN S32+=S8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct 
GMMA_64x256x64_F16E4M3E4M3_RS_TN +struct GMMA_64x128x64_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -38319,11 +8943,12 @@ struct GMMA_64x256x64_F16E4M3E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %71, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e4m3.e4m3 " + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -38335,7 +8960,7 @@ struct GMMA_64x256x64_F16E4M3E4M3_RS_TN "{%64, %65, %66, %67}," " %68," " %69, %70," - " p, %72, %73;\n" + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -38356,73 +8981,64 @@ struct GMMA_64x256x64_F16E4M3E4M3_RS_TN : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x256x64 TN F32+=E4M3*E4M3 +// SPARSE GMMA 64x192x64 TN S32+=S8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x256x64_F32E4M3E4M3_SS_TN +struct GMMA_64x192x64_S32S8U8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; 
using BRegisters = uint64_t[1]; - using CRegisters = float[128]; + using CRegisters = uint32_t[96]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, 
float & d119, - float & d120, float & d121, float & d122, float & d123, - float & d124, float & d125, float & d126, float & d127, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %132, 0;\n" - 
"wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e4m3.e4m3 " + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -38434,118 +9050,97 @@ struct GMMA_64x256x64_F32E4M3E4M3_SS_TN " %64, %65, %66, %67, %68, %69, %70, %71, " " %72, %73, %74, %75, %76, %77, %78, %79, " " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - " %128," - " %129," - " %130, %131," - " p, %133, %134;\n" + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), 
"+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), - "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), - "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) - : "l"(desc_a), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting 
to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x256x64 TN F32+=E4M3*E4M3 +// SPARSE GMMA 64x192x64 TN S32+=S8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x256x64_F32E4M3E4M3_RS_TN +struct GMMA_64x192x64_S32S8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[128]; + using CRegisters = uint32_t[96]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, 
- float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, - float & d120, float & d121, float & d122, float & d123, - float & d124, float & d125, float & d126, float & d127, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, 
+ uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %135, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e4m3.e4m3 " + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -38557,251 +9152,342 @@ struct GMMA_64x256x64_F32E4M3E4M3_RS_TN " %64, %65, %66, %67, %68, %69, %70, %71, " " %72, %73, %74, %75, %76, %77, %78, %79, " " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - "{%128, %129, %130, %131}," - " %132," - " %133, %134," - " p, %136, %137;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), 
"+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), - "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), - "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// SPARSE GMMA 64x8x64 TN F16+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x8x64_F16E4M3E5M2_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[2]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - 
uint32_t & d0, uint32_t & d1, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %6, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e4m3.e5m2 " - "{%0, %1}," - " %2," - " %3," - " %4, %5," - " p, %7, %8;\n" + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" "}\n" - : "+r"(d0), "+r"(d1) - : "l"(desc_a), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x8x64 TN F16+=E4M3*E5M2 +// SPARSE GMMA 64x256x64 TN S32+=S8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x8x64_F16E4M3E5M2_RS_TN +struct GMMA_64x256x64_S32S8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[2]; + using CRegisters = uint32_t[128]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t 
& d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %9, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e4m3.e5m2 " - "{%0, %1}," - "{%2, %3, %4, %5}," - " %6," - " %7, %8," - " p, %10, %11;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + 
" %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" "}\n" - : "+r"(d0), "+r"(d1) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), 
"+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x8x64 TN F32+=E4M3*E5M2 +// SPARSE GMMA 64x256x64 TN S32+=S8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x8x64_F32E4M3E5M2_SS_TN +struct GMMA_64x256x64_S32S8U8_RS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[4]; + using CRegisters = uint32_t[128]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & 
d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - 
"setp.ne.b32 p, %8, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e4m3.e5m2 " - "{%0, %1, %2, %3}," - " %4," - " %5," - " %6, %7," - " p, %9, %10;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) - : "l"(desc_a), + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + 
"+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x8x64 TN F32+=E4M3*E5M2 +// SPARSE GMMA 64x8x64 TN S32+=U8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x8x64_F32E4M3E5M2_RS_TN +struct GMMA_64x8x64_S32U8S8_SS_TN { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[4]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + fma(uint64_t const& 
desc_a, uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %11, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e4m3.e5m2 " + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.s8 " "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," - " %8," - " %9, %10," - " p, %12, %13;\n" + " %4," + " %5," + " %6, %7," + " p;\n" "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x16x64 TN F16+=E4M3*E5M2 +// SPARSE GMMA 64x8x64 TN S32+=U8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x16x64_F16E4M3E5M2_SS_TN +struct GMMA_64x8x64_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -38817,388 +9503,436 @@ struct GMMA_64x16x64_F16E4M3E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" 
"setp.ne.b32 p, %8, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e4m3.e5m2 " + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.s8.satfinite " "{%0, %1, %2, %3}," " %4," " %5," " %6, %7," - " p, %9, %10;\n" + " p;\n" "}\n" : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x16x64 TN F16+=E4M3*E5M2 +// SPARSE GMMA 64x16x64 TN S32+=U8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x16x64_F16E4M3E5M2_RS_TN +struct GMMA_64x16x64_S32U8S8_SS_TN { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[4]; + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %11, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e4m3.e5m2 " - "{%0, %1, %2, %3}," - "{%4, %5, 
%6, %7}," + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," " %8," - " %9, %10," - " p, %12, %13;\n" + " %9," + " %10, %11," + " p;\n" "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x16x64 TN F32+=E4M3*E5M2 +// SPARSE GMMA 64x16x64 TN S32+=U8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x16x64_F32E4M3E5M2_SS_TN +struct GMMA_64x16x64_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[8]; + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, - float & d4, float & d5, float & d6, float & d7, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %12, 0;\n" - 
"wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e4m3.e5m2 " + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7}," " %8," " %9," " %10, %11," - " p, %13, %14;\n" + " p;\n" "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), - "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x16x64 TN F32+=E4M3*E5M2 +// SPARSE GMMA 64x32x64 TN S32+=U8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x16x64_F32E4M3E5M2_RS_TN +struct GMMA_64x32x64_S32U8S8_SS_TN { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[8]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d0, float & d1, float & d2, float & d3, - float & d4, float & d5, float & d6, float & d7, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, 
uint32_t & d15, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %15, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "{%8, %9, %10, %11}," - " %12," - " %13, %14," - " p, %16, %17;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" "}\n" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), - "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x32x64 TN F16+=E4M3*E5M2 +// SPARSE GMMA 64x32x64 TN S32+=U8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x32x64_F16E4M3E5M2_SS_TN +struct GMMA_64x32x64_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; + using 
CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %12, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - " %8," - " %9," - " %10, %11," - " p, %13, %14;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x32x64 TN F16+=E4M3*E5M2 
+// SPARSE GMMA 64x64x64 TN S32+=U8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x32x64_F16E4M3E5M2_RS_TN +struct GMMA_64x64x64_S32U8S8_SS_TN { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[8]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %15, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7}," - "{%8, %9, %10, %11}," - " %12," - " %13, %14," - " p, %16, %17;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " 
%32," + " %33," + " %34, %35," + " p;\n" "}\n" - : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), - "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x32x64 TN F32+=E4M3*E5M2 +// SPARSE GMMA 64x64x64 TN S32+=U8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x32x64_F32E4M3E5M2_SS_TN +struct GMMA_64x64x64_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[16]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t 
& d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %20, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e4m3.e5m2 " + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - " %16," - " %17," - " %18, %19," - " p, %21, %22;\n" + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E5M2_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x32x64 TN F32+=E4M3*E5M2 +// SPARSE GMMA 64x96x64 TN S32+=U8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x32x64_F32E4M3E5M2_RS_TN +struct GMMA_64x96x64_S32U8S8_SS_TN { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[16]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & 
d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %23, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e4m3.e5m2 " + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - "{%16, %17, %18, %19}," - " %20," - " %21, %22," - " p, %24, %25;\n" + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x64 TN F16+=E4M3*E5M2 +// SPARSE GMMA 64x96x64 TN S32+=U8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x48x64_F16E4M3E5M2_SS_TN +struct GMMA_64x96x64_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[12]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -39206,224 +9940,337 @@ struct GMMA_64x48x64_F16E4M3E5M2_SS_TN uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %16, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e4m3.e5m2 " + "setp.ne.b32 
p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11}," - " %12," - " %13," - " %14, %15," - " p, %17, %18;\n" + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x64 TN F16+=E4M3*E5M2 +// SPARSE GMMA 64x128x64 TN S32+=U8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x48x64_F16E4M3E5M2_RS_TN +struct GMMA_64x128x64_S32U8S8_SS_TN { using 
DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[12]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %19, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11}," - "{%12, %13, %14, %15}," - " %16," - " %17, %18," - " p, %20, %21;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.s8 " + "{%0, %1, %2, %3, 
%4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x64 TN F32+=E4M3*E5M2 +// SPARSE GMMA 64x128x64 TN S32+=U8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x48x64_F32E4M3E5M2_SS_TN +struct GMMA_64x128x64_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[24]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm 
volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %28, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " %26, %27," - " p, %29, %30;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x64 TN F32+=E4M3*E5M2 +// SPARSE GMMA 64x192x64 TN S32+=U8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x48x64_F32E4M3E5M2_RS_TN +struct GMMA_64x192x64_S32U8S8_SS_TN { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[24]; + using CRegisters = uint32_t[96]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, 
uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %31, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " %29, %30," - " p, %32, %33;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, 
%58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x64x64 TN F16+=E4M3*E5M2 +// SPARSE GMMA 64x192x64 TN S32+=U8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x64x64_F16E4M3E5M2_SS_TN +struct GMMA_64x192x64_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; + using CRegisters = uint32_t[96]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -39432,544 +10279,632 @@ struct GMMA_64x64x64_F16E4M3E5M2_SS_TN uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & 
d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %20, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - " %16," - " %17," - " %18, %19," - " p, %21, %22;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), 
+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x64x64 TN F16+=E4M3*E5M2 +// SPARSE GMMA 64x256x64 TN S32+=U8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x64x64_F16E4M3E5M2_RS_TN +struct GMMA_64x256x64_S32U8S8_SS_TN { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[16]; + using CRegisters = uint32_t[128]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & 
d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & 
d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %23, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15}," - "{%16, %17, %18, %19}," - " %20," - " %21, %22," - " p, %24, %25;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), 
"+r"(d15) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to 
use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x64x64 TN F32+=E4M3*E5M2 +// SPARSE GMMA 64x256x64 TN S32+=U8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x64x64_F32E4M3E5M2_SS_TN +struct GMMA_64x256x64_S32U8S8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[32]; + using CRegisters = uint32_t[128]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + 
uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %36, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - " %32," - " %33," - " %34, %35," - " p, 
%37, %38;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + 
"+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x64x64 TN F32+=E4M3*E5M2 +// SPARSE GMMA 64x8x64 TN S32+=U8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x64x64_F32E4M3E5M2_RS_TN +struct GMMA_64x8x64_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[32]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void - fma(uint32_t const& 
a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %39, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - "{%32, %33, %34, %35}," - " %36," - " %37, %38," - " p, %40, %41;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.s8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); 
+ "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x64 TN F16+=E4M3*E5M2 +// SPARSE GMMA 64x8x64 TN S32+=U8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x80x64_F16E4M3E5M2_SS_TN +struct GMMA_64x8x64_S32U8S8_RS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[20]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %24, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19}," - " %20," - " 
%21," - " %22, %23," - " p, %25, %26;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) - : "l"(desc_a), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x64 TN F16+=E4M3*E5M2 +// SPARSE GMMA 64x16x64 TN S32+=U8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x80x64_F16E4M3E5M2_RS_TN +struct GMMA_64x16x64_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[20]; + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t 
& d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %27, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19}," - "{%20, %21, %22, %23}," - " %24," - " %25, %26," - " p, %28, %29;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// 
SPARSE GMMA 64x80x64 TN F32+=E4M3*E5M2 +// SPARSE GMMA 64x16x64 TN S32+=U8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x80x64_F32E4M3E5M2_SS_TN +struct GMMA_64x16x64_S32U8S8_RS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[40]; + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %44, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " %42, %43," - " p, %45, %46;\n" + "setp.ne.b32 p, %15, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) - : "l"(desc_a), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x64 TN F32+=E4M3*E5M2 +// SPARSE GMMA 64x32x64 TN S32+=U8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x80x64_F32E4M3E5M2_RS_TN +struct GMMA_64x32x64_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[40]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& 
desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %47, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e4m3.e5m2 " + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " %45, %46," - " p, %48, %49;\n" + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), 
"+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x96x64 TN F16+=E4M3*E5M2 +// SPARSE GMMA 64x32x64 TN S32+=U8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x96x64_F16E4M3E5M2_SS_TN +struct GMMA_64x32x64_S32U8S8_RS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t const& e, GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %28, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e4m3.e5m2 " + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " %26, %27," - " p, %29, %30;\n" + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) - : "l"(desc_a), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x96x64 TN F16+=E4M3*E5M2 +// SPARSE GMMA 64x64x64 TN S32+=U8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x96x64_F16E4M3E5M2_RS_TN +struct GMMA_64x64x64_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using 
CRegisters = uint32_t[24]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -39980,152 +10915,146 @@ struct GMMA_64x96x64_F16E4M3E5M2_RS_TN uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %31, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e4m3.e5m2 " + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " %29, %30," - " p, %32, %33;\n" + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x96x64 TN F32+=E4M3*E5M2 +// SPARSE GMMA 64x64x64 TN S32+=U8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x96x64_F32E4M3E5M2_SS_TN +struct GMMA_64x64x64_S32U8S8_RS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[48]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, 
uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %52, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e4m3.e5m2 " + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " %50, %51," - " p, %53, %54;\n" + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) - : "l"(desc_a), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) 
+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x96x64 TN F32+=E4M3*E5M2 +// SPARSE GMMA 64x96x64 TN S32+=U8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x96x64_F32E4M3E5M2_RS_TN +struct GMMA_64x96x64_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[48]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, 
+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %55, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e4m3.e5m2 " + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -40135,49 +11064,46 @@ struct GMMA_64x96x64_F32E4M3E5M2_RS_TN "{%48, %49, %50, %51}," " %52," " %53, %54," - " p, %56, %57;\n" + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), 
"+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x64 TN F16+=E4M3*E5M2 +// SPARSE GMMA 64x96x64 TN S32+=U8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x112x64_F16E4M3E5M2_SS_TN +struct GMMA_64x96x64_S32U8S8_RS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[28]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, @@ -40186,23 +11112,31 @@ struct GMMA_64x112x64_F16E4M3E5M2_SS_TN uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & 
d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %32, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27}," - " %28," - " %29," - " %30, %31," - " p, %33, %34;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -40210,34 +11144,35 @@ struct GMMA_64x112x64_F16E4M3E5M2_SS_TN "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) - : "l"(desc_a), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + 
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x64 TN F16+=E4M3*E5M2 +// SPARSE GMMA 64x128x64 TN S32+=U8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x112x64_F16E4M3E5M2_RS_TN +struct GMMA_64x128x64_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[28]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -40249,23 +11184,37 @@ struct GMMA_64x112x64_F16E4M3E5M2_RS_TN uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, 
uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %35, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27}," - "{%28, %29, %30, %31}," - " %32," - " %33, %34," - " p, %36, %37;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -40273,196 +11222,226 @@ struct GMMA_64x112x64_F16E4M3E5M2_RS_TN "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), 
"+r"(d63) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x64 TN F32+=E4M3*E5M2 +// SPARSE GMMA 64x128x64 TN S32+=U8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x112x64_F32E4M3E5M2_SS_TN +struct GMMA_64x128x64_S32U8S8_RS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[56]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, 
float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %60, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e4m3.e5m2 " + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " %58, %59," - " p, %61, %62;\n" + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, 
%58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) - : "l"(desc_a), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif 
//////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x64 TN F32+=E4M3*E5M2 +// SPARSE GMMA 64x192x64 TN S32+=U8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x112x64_F32E4M3E5M2_RS_TN +struct GMMA_64x192x64_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[56]; + using CRegisters = uint32_t[96]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, 
uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %63, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e4m3.e5m2 " + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.s8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " %61, %62," - " p, %64, %65;\n" + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " 
%64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) 
: "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x128x64 TN F16+=E4M3*E5M2 +// SPARSE GMMA 64x192x64 TN S32+=U8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x128x64_F16E4M3E5M2_SS_TN +struct GMMA_64x192x64_S32U8S8_RS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; + using CRegisters = uint32_t[96]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, @@ -40472,23 +11451,48 @@ struct GMMA_64x128x64_F16E4M3E5M2_SS_TN uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & 
d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %36, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - " %32," - " %33," - " %34, %35," - " p, %37, %38;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" "}\n" : 
"+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -40497,124 +11501,214 @@ struct GMMA_64x128x64_F16E4M3E5M2_SS_TN "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) - : "l"(desc_a), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x128x64 TN F16+=E4M3*E5M2 +// SPARSE GMMA 64x256x64 TN S32+=U8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x128x64_F16E4M3E5M2_RS_TN +struct 
GMMA_64x256x64_S32U8S8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; + using CRegisters = uint32_t[128]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %39, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - "{%32, %33, %34, %35}," - " %36," - " %37, %38," - " p, %40, %41;\n" + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, 
uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), 
+ "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x128x64 TN F32+=E4M3*E5M2 +// SPARSE GMMA 64x256x64 TN S32+=U8*S8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x128x64_F32E4M3E5M2_SS_TN +struct GMMA_64x256x64_S32U8S8_RS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[64]; + using 
CRegisters = uint32_t[128]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, 
uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %68, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e4m3.e5m2 " + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.s8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -40622,453 +11716,311 @@ struct GMMA_64x128x64_F32E4M3E5M2_SS_TN " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " 
%48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - " %64," - " %65," - " %66, %67," - " p, %69, %70;\n" + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) - : "l"(desc_a), + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), 
"+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x128x64 TN F32+=E4M3*E5M2 +// SPARSE GMMA 64x8x64 TN S32+=U8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x128x64_F32E4M3E5M2_RS_TN +struct GMMA_64x8x64_S32U8U8_SS_TN { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = 
uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[64]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %71, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - "{%64, %65, %66, %67}," - " %68," - " %69, %70," - " p, %72, %73;\n" + 
"setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.u8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x64 TN F16+=E4M3*E5M2 +// SPARSE GMMA 64x8x64 TN S32+=U8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x144x64_F16E4M3E5M2_SS_TN +struct GMMA_64x8x64_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = 
uint64_t[1]; - using CRegisters = uint32_t[36]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %40, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35}," - " %36," - " %37," - " %38, %39," - " p, %41, %42;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "+r"(d0), "+r"(d1), 
"+r"(d2), "+r"(d3) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x64 TN F16+=E4M3*E5M2 +// SPARSE GMMA 64x16x64 TN S32+=U8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x144x64_F16E4M3E5M2_RS_TN +struct GMMA_64x16x64_S32U8U8_SS_TN { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[36]; + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + 
uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %43, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35}," - "{%36, %37, %38, %39}," - " %40," - " %41, %42," - " p, %44, %45;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x64 TN 
F32+=E4M3*E5M2 +// SPARSE GMMA 64x16x64 TN S32+=U8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x144x64_F32E4M3E5M2_SS_TN +struct GMMA_64x16x64_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[72]; + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %76, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, 
%6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - " %72," - " %73," - " %74, %75," - " p, %77, %78;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif 
//////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x64 TN F32+=E4M3*E5M2 +// SPARSE GMMA 64x32x64 TN S32+=U8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x144x64_F32E4M3E5M2_RS_TN +struct GMMA_64x32x64_S32U8U8_SS_TN { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[72]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, 
uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %79, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - "{%72, %73, %74, %75}," - " %76," - " %77, %78," - " p, %80, %81;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d00), 
"+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x64 TN F16+=E4M3*E5M2 +// SPARSE GMMA 64x32x64 TN S32+=U8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x160x64_F16E4M3E5M2_SS_TN +struct GMMA_64x32x64_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -41077,71 +12029,53 @@ struct GMMA_64x160x64_F16E4M3E5M2_SS_TN uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, 
uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %44, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e4m3.e5m2 " + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " %42, %43," - " p, %45, %46;\n" + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x64 TN F16+=E4M3*E5M2 +// SPARSE GMMA 64x64x64 TN S32+=U8*U8 template < - 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x160x64_F16E4M3E5M2_RS_TN +struct GMMA_64x64x64_S32U8U8_SS_TN { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, @@ -41151,26 +12085,24 @@ struct GMMA_64x160x64_F16E4M3E5M2_RS_TN uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %47, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e4m3.e5m2 " + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " %45, %46," - " p, %48, %49;\n" + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), 
"+r"(d07), @@ -41179,226 +12111,92 @@ struct GMMA_64x160x64_F16E4M3E5M2_RS_TN "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x64 TN F32+=E4M3*E5M2 +// SPARSE GMMA 64x64x64 TN S32+=U8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x160x64_F32E4M3E5M2_SS_TN +struct GMMA_64x64x64_S32U8U8_SS_TN_SATURATE { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[80]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, 
float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %84, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, 
%78, %79}," - " %80," - " %81," - " %82, %83," - " p, %85, %86;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x64 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x160x64_F32E4M3E5M2_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[80]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %87, 0;\n" - 
"wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - "{%80, %81, %82, %83}," - " %84," - " %85, %86," - " p, %88, %89;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; 
-#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x64 TN F16+=E4M3*E5M2 +// SPARSE GMMA 64x96x64 TN S32+=U8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x176x64_F16E4M3E5M2_SS_TN +struct GMMA_64x96x64_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[44]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -41414,25 +12212,27 @@ struct GMMA_64x176x64_F16E4M3E5M2_SS_TN uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %48, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e4m3.e5m2 " + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43}," - " %44," - " %45," - " %46, %47," - " p, %49, %50;\n" + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -41444,37 +12244,34 @@ struct GMMA_64x176x64_F16E4M3E5M2_SS_TN 
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x64 TN F16+=E4M3*E5M2 +// SPARSE GMMA 64x96x64 TN S32+=U8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x176x64_F16E4M3E5M2_RS_TN +struct GMMA_64x96x64_S32U8U8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[44]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, @@ -41487,25 +12284,27 @@ struct GMMA_64x176x64_F16E4M3E5M2_RS_TN uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, 
uint32_t & d45, uint32_t & d46, uint32_t & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %51, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e4m3.e5m2 " + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43}," - "{%44, %45, %46, %47}," + " %40, %41, %42, %43, %44, %45, %46, %47}," " %48," - " %49, %50," - " p, %52, %53;\n" + " %49," + " %50, %51," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -41517,69 +12316,61 @@ struct GMMA_64x176x64_F16E4M3E5M2_RS_TN "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x64 TN F32+=E4M3*E5M2 +// SPARSE GMMA 
64x128x64 TN S32+=U8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x176x64_F32E4M3E5M2_SS_TN +struct GMMA_64x128x64_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[88]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, 
uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %92, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e4m3.e5m2 " + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -41587,99 +12378,81 @@ struct GMMA_64x176x64_F32E4M3E5M2_SS_TN " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - " %88," - " %89," - " %90, %91," - " p, %93, %94;\n" + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), 
"+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x64 TN F32+=E4M3*E5M2 +// SPARSE GMMA 64x128x64 TN S32+=U8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x176x64_F32E4M3E5M2_RS_TN +struct GMMA_64x128x64_S32U8U8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[88]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & 
d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %95, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e4m3.e5m2 " + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -41687,63 +12460,51 @@ struct GMMA_64x176x64_F32E4M3E5M2_RS_TN " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " 
%72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - "{%88, %89, %90, %91}," - " %92," - " %93, %94," - " p, %96, %97;\n" + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), 
"+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x192x64 TN F16+=E4M3*E5M2 +// SPARSE GMMA 64x192x64 TN S32+=U8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x192x64_F16E4M3E5M2_SS_TN +struct GMMA_64x192x64_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; + using CRegisters = uint32_t[96]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -41760,25 +12521,44 @@ struct GMMA_64x192x64_F16E4M3E5M2_SS_TN uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & 
d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %52, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " %50, %51," - " p, %53, %54;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -41791,35 +12571,45 @@ struct GMMA_64x192x64_F16E4M3E5M2_SS_TN "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), 
"+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x192x64 TN F16+=E4M3*E5M2 +// SPARSE GMMA 64x192x64 TN S32+=U8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x192x64_F16E4M3E5M2_RS_TN +struct GMMA_64x192x64_S32U8U8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; + using CRegisters = uint32_t[96]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, @@ -41833,25 +12623,44 @@ struct GMMA_64x192x64_F16E4M3E5M2_RS_TN uint32_t & d36, uint32_t & d37, uint32_t & 
d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %55, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e4m3.e5m2 " + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " %53, %54," - " p, %56, %57;\n" + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, 
%95}," + " %96," + " %97," + " %98, %99," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -41864,69 +12673,88 @@ struct GMMA_64x192x64_F16E4M3E5M2_RS_TN "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x192x64 TN F32+=E4M3*E5M2 +// SPARSE GMMA 64x256x64 TN S32+=U8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x192x64_F32E4M3E5M2_SS_TN +struct GMMA_64x256x64_S32U8U8_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using 
BRegisters = uint64_t[1]; - using CRegisters = float[96]; + using CRegisters = uint32_t[128]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, - float & d88, float & d89, float & d90, float & d91, - float & d92, float & d93, float & d94, float & d95, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, 
uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %100, 0;\n" - 
"wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e4m3.e5m2 " + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -41938,98 +12766,117 @@ struct GMMA_64x192x64_F32E4M3E5M2_SS_TN " %64, %65, %66, %67, %68, %69, %70, %71, " " %72, %73, %74, %75, %76, %77, %78, %79, " " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - " %96," - " %97," - " %98, %99," - " p, %101, %102;\n" + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), - "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), - "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : 
"+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x192x64 TN F32+=E4M3*E5M2 +// SPARSE GMMA 64x256x64 TN S32+=U8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x192x64_F32E4M3E5M2_RS_TN +struct GMMA_64x256x64_S32U8U8_SS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint32_t[4]; + using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[96]; + using CRegisters = uint32_t[128]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - 
float & d84, float & d85, float & d86, float & d87, - float & d88, float & d89, float & d90, float & d91, - float & d92, float & d93, float & d94, float & d95, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, 
uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %103, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e4m3.e5m2 " + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -42041,518 +12888,359 @@ struct GMMA_64x192x64_F32E4M3E5M2_RS_TN " %64, %65, %66, %67, %68, %69, %70, %71, " " %72, %73, %74, %75, %76, %77, %78, %79, " " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - "{%96, %97, %98, %99}," - " %100," - " %101, %102," - " p, %104, %105;\n" + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - 
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), - "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), - "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), 
"+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x64 TN F16+=E4M3*E5M2 +// SPARSE GMMA 64x8x64 TN S32+=U8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x208x64_F16E4M3E5M2_SS_TN +struct GMMA_64x8x64_S32U8U8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[52]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, 
uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %56, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51}," - " %52," - " %53," - " %54, %55," - " p, %57, %58;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.u8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) - : 
"l"(desc_a), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif + } +}; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x64 TN F16+=E4M3*E5M2 +// SPARSE GMMA 64x16x64 TN S32+=U8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x208x64_F16E4M3E5M2_RS_TN +struct GMMA_64x16x64_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[52]; + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %59, 
0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51}," - "{%52, %53, %54, %55}," - " %56," - " %57, %58," - " p, %60, %61;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 
64x208x64 TN F32+=E4M3*E5M2 +// SPARSE GMMA 64x16x64 TN S32+=U8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x208x64_F32E4M3E5M2_SS_TN +struct GMMA_64x16x64_S32U8U8_RS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[104]; + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & 
d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %108, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - " %104," - " %105," - " %106, %107," - " p, %109, %110;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - 
"+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) - : "l"(desc_a), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x64 TN F32+=E4M3*E5M2 +// SPARSE GMMA 64x32x64 TN S32+=U8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x208x64_F32E4M3E5M2_RS_TN +struct GMMA_64x32x64_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = 
float[104]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + 
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %111, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - "{%104, %105, %106, %107}," - " %108," - " %109, %110," - " p, %112, %113;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), 
"+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x64 TN F16+=E4M3*E5M2 +// SPARSE GMMA 64x32x64 TN S32+=U8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x224x64_F16E4M3E5M2_SS_TN +struct GMMA_64x32x64_S32U8U8_RS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; + using CRegisters = uint32_t[16]; 
CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %60, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " %58, %59," - " p, %61, %62;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " 
p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) - : "l"(desc_a), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x64 TN F16+=E4M3*E5M2 +// SPARSE GMMA 64x64x64 TN S32+=U8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x224x64_F16E4M3E5M2_RS_TN +struct GMMA_64x64x64_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -42565,32 +13253,24 
@@ struct GMMA_64x224x64_F16E4M3E5M2_RS_TN uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %63, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " %61, %62," - " p, %64, %65;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -42599,273 +13279,167 @@ struct GMMA_64x224x64_F16E4M3E5M2_RS_TN "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), 
"+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x64 TN F32+=E4M3*E5M2 +// SPARSE GMMA 64x64x64 TN S32+=U8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x224x64_F32E4M3E5M2_SS_TN +struct GMMA_64x64x64_S32U8U8_RS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[112]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & 
d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm 
volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %116, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - " %112," - " %113," - " %114, %115," - " p, %117, %118;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), 
"+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) - : "l"(desc_a), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x64 TN F32+=E4M3*E5M2 +// SPARSE GMMA 64x96x64 TN S32+=U8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x224x64_F32E4M3E5M2_RS_TN +struct GMMA_64x96x64_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = 
uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[112]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, + uint32_t & d00, 
uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %119, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e4m3.e5m2 " + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - "{%112, %113, %114, %115}," - " %116," - " %117, %118," - " p, %120, %121;\n" + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" "}\n" - 
: "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) 
+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x64 TN F16+=E4M3*E5M2 +// SPARSE GMMA 64x96x64 TN S32+=U8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x240x64_F16E4M3E5M2_SS_TN +struct GMMA_64x96x64_S32U8U8_RS_TN_SATURATE { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[60]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, @@ -42879,30 +13453,26 @@ struct GMMA_64x240x64_F16E4M3E5M2_SS_TN uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, uint32_t const& e, GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %64, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e4m3.e5m2 " + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59}," - " %60," - " %61," - " %62, %63," - " p, %65, %66;\n" + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -42915,37 +13485,30 @@ struct GMMA_64x240x64_F16E4M3E5M2_SS_TN "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) - : "l"(desc_a), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x64 TN F16+=E4M3*E5M2 +// SPARSE GMMA 64x128x64 TN S32+=U8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x240x64_F16E4M3E5M2_RS_TN +struct GMMA_64x128x64_S32U8U8_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[60]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -42965,15 +13528,17 @@ struct GMMA_64x240x64_F16E4M3E5M2_RS_TN uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %67, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e4m3.e5m2 " + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -42981,11 +13546,11 @@ struct GMMA_64x240x64_F16E4M3E5M2_RS_TN " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59}," - "{%60, %61, %62, %63}," - " %64," - " %65, %66," - " p, %68, %69;\n" + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), 
"+r"(d07), @@ -43001,197 +13566,61 @@ struct GMMA_64x240x64_F16E4M3E5M2_RS_TN "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x64 TN F32+=E4M3*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x240x64_F32E4M3E5M2_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[120]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & 
d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %124, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e4m3.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, 
%116, %117, %118, %119}," - " %120," - " %121," - " %122, %123," - " p, %125, %126;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_RS_TN 
without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x64 TN F32+=E4M3*E5M2 +// SPARSE GMMA 64x128x64 TN S32+=U8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x240x64_F32E4M3E5M2_RS_TN +struct GMMA_64x128x64_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[120]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & 
d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %127, 
0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e4m3.e5m2 " + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -43199,78 +13628,54 @@ struct GMMA_64x240x64_F32E4M3E5M2_RS_TN " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - "{%120, %121, %122, %123}," - " %124," - " %125, %126," - " p, %128, %129;\n" + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), 
"+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x256x64 TN F16+=E4M3*E5M2 +// 
SPARSE GMMA 64x192x64 TN S32+=U8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x256x64_F16E4M3E5M2_SS_TN +struct GMMA_64x192x64_S32U8U8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; + using CRegisters = uint32_t[96]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, @@ -43288,15 +13693,24 @@ struct GMMA_64x256x64_F16E4M3E5M2_SS_TN uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %68, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e4m3.e5m2 " + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.u8 " "{%0, %1, %2, %3, %4, %5, 
%6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -43304,11 +13718,15 @@ struct GMMA_64x256x64_F16E4M3E5M2_SS_TN " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - " %64," - " %65," - " %66, %67," - " p, %69, %70;\n" + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -43325,32 +13743,38 @@ struct GMMA_64x256x64_F16E4M3E5M2_SS_TN "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) - : "l"(desc_a), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; 
//////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x256x64 TN F16+=E4M3*E5M2 +// SPARSE GMMA 64x192x64 TN S32+=U8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x256x64_F16E4M3E5M2_RS_TN +struct GMMA_64x192x64_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[64]; + using CRegisters = uint32_t[96]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -43371,15 +13795,24 @@ struct GMMA_64x256x64_F16E4M3E5M2_RS_TN uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %71, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e4m3.e5m2 " + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " 
" %16, %17, %18, %19, %20, %21, %22, %23, " @@ -43387,11 +13820,15 @@ struct GMMA_64x256x64_F16E4M3E5M2_RS_TN " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - "{%64, %65, %66, %67}," - " %68," - " %69, %70," - " p, %72, %73;\n" + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -43408,200 +13845,206 @@ struct GMMA_64x256x64_F16E4M3E5M2_RS_TN "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), - "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE 
GMMA 64x256x64 TN F32+=E4M3*E5M2 +// SPARSE GMMA 64x256x64 TN S32+=U8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x256x64_F32E4M3E5M2_SS_TN +struct GMMA_64x256x64_S32U8U8_RS_TN { using DRegisters = void; - using ARegisters = uint64_t[1]; + using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[128]; + using CRegisters = uint32_t[128]; CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, 
float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, - float & d120, float & d121, float & d122, float & d123, - float & d124, float & d125, float & d126, float & d127, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, 
uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %132, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e4m3.e5m2 " + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.u8 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - " %128," - " %129," - " %130, %131," - " p, %133, 
%134;\n" - "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), - "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), - "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) - : "l"(desc_a), + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, 
%109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x256x64 TN F32+=E4M3*E5M2 +// SPARSE GMMA 64x256x64 TN S32+=U8*U8 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x256x64_F32E4M3E5M2_RS_TN +struct GMMA_64x256x64_S32U8U8_RS_TN_SATURATE { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[128]; + using CRegisters = uint32_t[128]; CUTE_HOST_DEVICE static void fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, 
float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, - float & d120, float & d121, float & d122, float & d123, - float & d124, float & d125, float & d126, float & d127, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & 
d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %135, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e4m3.e5m2 " + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.u8.satfinite " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -43621,59 +14064,59 @@ struct GMMA_64x256x64_F32E4M3E5M2_RS_TN "{%128, %129, %130, %131}," " %132," " %133, %134," - " p, %136, %137;\n" + " p;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - 
"+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), - "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), - "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), 
+ "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); + "r"(int32_t(scale_D))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x8x64 TN F16+=E5M2*E4M3 +// SPARSE GMMA 64x8x64 TN F16+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x8x64_F16E5M2E4M3_SS_TN +struct GMMA_64x8x64_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -43689,11 +14132,12 @@ struct 
GMMA_64x8x64_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %6, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e5m2.e4m3 " + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e4m3.e4m3 " "{%0, %1}," " %2," " %3," @@ -43706,20 +14150,20 @@ struct GMMA_64x8x64_F16E5M2E4M3_SS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x8x64 TN F16+=E5M2*E4M3 +// SPARSE GMMA 64x8x64 TN F16+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x8x64_F16E5M2E4M3_RS_TN +struct GMMA_64x8x64_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -43735,11 +14179,12 @@ struct GMMA_64x8x64_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %9, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e5m2.e4m3 " + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e4m3.e4m3 " "{%0, %1}," "{%2, %3, %4, %5}," " %6," @@ -43752,20 +14197,20 @@ struct GMMA_64x8x64_F16E5M2E4M3_RS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x8x64 TN F32+=E5M2*E4M3 +// SPARSE GMMA 64x8x64 TN F32+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x8x64_F32E5M2E4M3_SS_TN +struct GMMA_64x8x64_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -43781,11 +14226,12 @@ struct GMMA_64x8x64_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %8, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e5m2.e4m3 " + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e4m3.e4m3 " "{%0, %1, %2, %3}," " %4," " %5," @@ -43798,20 +14244,20 @@ struct GMMA_64x8x64_F32E5M2E4M3_SS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x8x64 TN F32+=E5M2*E4M3 +// SPARSE GMMA 64x8x64 TN F32+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x8x64_F32E5M2E4M3_RS_TN +struct 
GMMA_64x8x64_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -43827,11 +14273,12 @@ struct GMMA_64x8x64_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %11, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e5m2.e4m3 " + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e4m3.e4m3 " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," " %8," @@ -43844,20 +14291,20 @@ struct GMMA_64x8x64_F32E5M2E4M3_RS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x16x64 TN F16+=E5M2*E4M3 +// SPARSE GMMA 64x16x64 TN F16+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x16x64_F16E5M2E4M3_SS_TN +struct GMMA_64x16x64_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -43873,11 +14320,12 @@ struct GMMA_64x16x64_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %8, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e5m2.e4m3 " + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e4m3.e4m3 " "{%0, %1, %2, %3}," " %4," " %5," @@ -43890,20 +14338,20 @@ struct GMMA_64x16x64_F16E5M2E4M3_SS_TN "r"(e), 
"n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x16x64 TN F16+=E5M2*E4M3 +// SPARSE GMMA 64x16x64 TN F16+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x16x64_F16E5M2E4M3_RS_TN +struct GMMA_64x16x64_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -43919,11 +14367,12 @@ struct GMMA_64x16x64_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %11, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e5m2.e4m3 " + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e4m3.e4m3 " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," " %8," @@ -43936,20 +14385,20 @@ struct GMMA_64x16x64_F16E5M2E4M3_RS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x16x64 TN F32+=E5M2*E4M3 +// SPARSE GMMA 64x16x64 TN F32+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x16x64_F32E5M2E4M3_SS_TN +struct GMMA_64x16x64_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -43966,11 +14415,12 @@ struct GMMA_64x16x64_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %12, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e5m2.e4m3 " + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7}," " %8," " %9," @@ -43984,20 +14434,20 @@ struct GMMA_64x16x64_F32E5M2E4M3_SS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x16x64 TN F32+=E5M2*E4M3 +// SPARSE GMMA 64x16x64 TN F32+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x16x64_F32E5M2E4M3_RS_TN +struct GMMA_64x16x64_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -44014,11 +14464,12 @@ struct GMMA_64x16x64_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %15, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e5m2.e4m3 " + 
"wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7}," "{%8, %9, %10, %11}," " %12," @@ -44032,20 +14483,20 @@ struct GMMA_64x16x64_F32E5M2E4M3_RS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x32x64 TN F16+=E5M2*E4M3 +// SPARSE GMMA 64x32x64 TN F16+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x32x64_F16E5M2E4M3_SS_TN +struct GMMA_64x32x64_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -44062,11 +14513,12 @@ struct GMMA_64x32x64_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %12, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e5m2.e4m3 " + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7}," " %8," " %9," @@ -44080,20 +14532,20 @@ struct GMMA_64x32x64_F16E5M2E4M3_SS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; 
//////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x32x64 TN F16+=E5M2*E4M3 +// SPARSE GMMA 64x32x64 TN F16+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x32x64_F16E5M2E4M3_RS_TN +struct GMMA_64x32x64_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -44110,11 +14562,12 @@ struct GMMA_64x32x64_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %15, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e5m2.e4m3 " + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7}," "{%8, %9, %10, %11}," " %12," @@ -44128,20 +14581,20 @@ struct GMMA_64x32x64_F16E5M2E4M3_RS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x32x64 TN F32+=E5M2*E4M3 +// SPARSE GMMA 64x32x64 TN F32+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x32x64_F32E5M2E4M3_SS_TN +struct GMMA_64x32x64_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -44160,11 +14613,12 @@ struct GMMA_64x32x64_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %20, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e5m2.e4m3 " + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," " %16," @@ -44181,20 +14635,20 @@ struct GMMA_64x32x64_F32E5M2E4M3_SS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x32x64 TN F32+=E5M2*E4M3 +// SPARSE GMMA 64x32x64 TN F32+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x32x64_F32E5M2E4M3_RS_TN +struct GMMA_64x32x64_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -44213,11 +14667,12 @@ struct GMMA_64x32x64_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %23, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e5m2.e4m3 " + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," "{%16, %17, %18, %19}," @@ -44234,246 +14689,20 @@ struct GMMA_64x32x64_F32E5M2E4M3_RS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x64 TN F16+=E5M2*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x48x64_F16E5M2E4M3_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[12]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %16, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11}," - " %12," - " %13," - " %14, %15," - " p, %17, %18;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 
64x48x64 TN F16+=E5M2*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x48x64_F16E5M2E4M3_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[12]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %19, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11}," - "{%12, %13, %14, %15}," - " %16," - " %17, %18," - " p, %20, %21;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x64 TN F32+=E5M2*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x48x64_F32E5M2E4M3_SS_TN -{ - 
using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[24]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %28, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " %26, %27," - " p, %29, %30;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x64 TN F32+=E5M2*E4M3 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x48x64_F32E5M2E4M3_RS_TN -{ - 
using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[24]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %31, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " %29, %30," - " p, %32, %33;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x64x64 TN F16+=E5M2*E4M3 +// SPARSE GMMA 
64x64x64 TN F16+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x64x64_F16E5M2E4M3_SS_TN +struct GMMA_64x64x64_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -44492,11 +14721,12 @@ struct GMMA_64x64x64_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %20, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e5m2.e4m3 " + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," " %16," @@ -44513,20 +14743,20 @@ struct GMMA_64x64x64_F16E5M2E4M3_SS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x64x64 TN F16+=E5M2*E4M3 +// SPARSE GMMA 64x64x64 TN F16+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x64x64_F16E5M2E4M3_RS_TN +struct GMMA_64x64x64_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -44545,11 +14775,12 @@ struct GMMA_64x64x64_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred 
p;\n" "setp.ne.b32 p, %23, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e5m2.e4m3 " + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," "{%16, %17, %18, %19}," @@ -44566,20 +14797,20 @@ struct GMMA_64x64x64_F16E5M2E4M3_RS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x64x64 TN F32+=E5M2*E4M3 +// SPARSE GMMA 64x64x64 TN F32+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x64x64_F32E5M2E4M3_SS_TN +struct GMMA_64x64x64_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -44602,11 +14833,12 @@ struct GMMA_64x64x64_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %36, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e5m2.e4m3 " + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -44629,20 +14861,20 @@ struct GMMA_64x64x64_F32E5M2E4M3_SS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E4M3_SS_TN 
without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x64x64 TN F32+=E5M2*E4M3 +// SPARSE GMMA 64x64x64 TN F32+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x64x64_F32E5M2E4M3_RS_TN +struct GMMA_64x64x64_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -44665,11 +14897,12 @@ struct GMMA_64x64x64_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %39, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e5m2.e4m3 " + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -44692,27 +14925,26 @@ struct GMMA_64x64x64_F32E5M2E4M3_RS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x64 TN F16+=E5M2*E4M3 +// SPARSE GMMA 64x96x64 TN F16+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, 
GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x80x64_F16E5M2E4M3_SS_TN +struct GMMA_64x96x64_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[20]; + using CRegisters = uint32_t[24]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -44722,55 +14954,56 @@ struct GMMA_64x80x64_F16E5M2E4M3_SS_TN uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %24, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e5m2.e4m3 " + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19}," - " %20," - " %21," - " %22, %23," - " p, %25, %26;\n" + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to 
use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x64 TN F16+=E5M2*E4M3 +// SPARSE GMMA 64x96x64 TN F16+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x80x64_F16E5M2E4M3_RS_TN +struct GMMA_64x96x64_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[20]; + using CRegisters = uint32_t[24]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -44780,55 +15013,56 @@ struct GMMA_64x80x64_F16E5M2E4M3_RS_TN uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %27, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e5m2.e4m3 " + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19}," - "{%20, %21, %22, %23}," - " %24," - " %25, %26," - " p, %28, %29;\n" + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), 
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x64 TN F32+=E5M2*E4M3 +// SPARSE GMMA 64x96x64 TN F32+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x80x64_F32E5M2E4M3_SS_TN +struct GMMA_64x96x64_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[40]; + using CRegisters = float[48]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -44843,24 +15077,28 @@ struct GMMA_64x80x64_F32E5M2E4M3_SS_TN float & d28, float & d29, float & d30, float & d31, float & d32, float & d33, float & d34, float & d35, float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %44, 0;\n" - 
"wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e5m2.e4m3 " + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " %42, %43," - " p, %45, %46;\n" + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -44871,34 +15109,34 @@ struct GMMA_64x80x64_F32E5M2E4M3_SS_TN "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x64 TN F32+=E5M2*E4M3 +// SPARSE GMMA 64x96x64 TN F32+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x80x64_F32E5M2E4M3_RS_TN +struct GMMA_64x96x64_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using 
ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[40]; + using CRegisters = float[48]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -44913,24 +15151,28 @@ struct GMMA_64x80x64_F32E5M2E4M3_RS_TN float & d28, float & d29, float & d30, float & d31, float & d32, float & d33, float & d34, float & d35, float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %47, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " %45, %46," - " p, %48, %49;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -44941,33 +15183,34 @@ struct GMMA_64x80x64_F32E5M2E4M3_RS_TN "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), 
"+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x96x64 TN F16+=E5M2*E4M3 +// SPARSE GMMA 64x128x64 TN F16+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x96x64_F16E5M2E4M3_SS_TN +struct GMMA_64x128x64_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -44978,54 +15221,60 @@ struct GMMA_64x96x64_F16E5M2E4M3_SS_TN uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %28, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e5m2.e4m3 " + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e4m3.e4m3 " "{%0, %1, %2, %3, 
%4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " %26, %27," - " p, %29, %30;\n" + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x96x64 TN F16+=E5M2*E4M3 +// SPARSE GMMA 64x128x64 TN F16+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x96x64_F16E5M2E4M3_RS_TN +struct GMMA_64x128x64_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -45036,54 +15285,60 @@ struct GMMA_64x96x64_F16E5M2E4M3_RS_TN uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t 
& d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %31, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e5m2.e4m3 " + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " %29, %30," - " p, %32, %33;\n" + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x96x64 TN 
F32+=E5M2*E4M3 +// SPARSE GMMA 64x128x64 TN F32+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x96x64_F32E5M2E4M3_SS_TN +struct GMMA_64x128x64_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[48]; + using CRegisters = float[64]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -45100,25 +15355,32 @@ struct GMMA_64x96x64_F32E5M2E4M3_SS_TN float & d36, float & d37, float & d38, float & d39, float & d40, float & d41, float & d42, float & d43, float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %52, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " %50, %51," - " p, %53, %54;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, 
%55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -45131,32 +15393,36 @@ struct GMMA_64x96x64_F32E5M2E4M3_SS_TN "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x96x64 TN F32+=E5M2*E4M3 +// SPARSE GMMA 64x128x64 TN F32+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x96x64_F32E5M2E4M3_RS_TN +struct GMMA_64x128x64_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[48]; + using CRegisters = float[64]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -45173,25 +15439,32 @@ struct GMMA_64x96x64_F32E5M2E4M3_RS_TN float & d36, float & d37, float & d38, float & d39, float & d40, float & d41, float & d42, float & d43, float & d44, 
float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %55, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e5m2.e4m3 " + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " %53, %54," - " p, %56, %57;\n" + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -45204,33 +15477,36 @@ struct GMMA_64x96x64_F32E5M2E4M3_RS_TN "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x64 TN F16+=E5M2*E4M3 +// SPARSE GMMA 64x192x64 TN F16+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x112x64_F16E5M2E4M3_SS_TN +struct GMMA_64x192x64_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[28]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -45242,23 +15518,31 @@ struct GMMA_64x112x64_F16E5M2E4M3_SS_TN uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %32, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e5m2.e4m3 " + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, 
%9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27}," - " %28," - " %29," - " %30, %31," - " p, %33, %34;\n" + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -45266,34 +15550,37 @@ struct GMMA_64x112x64_F16E5M2E4M3_SS_TN "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x64 TN F16+=E5M2*E4M3 +// SPARSE GMMA 64x192x64 TN F16+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x112x64_F16E5M2E4M3_RS_TN +struct GMMA_64x192x64_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using 
CRegisters = uint32_t[28]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -45305,23 +15592,31 @@ struct GMMA_64x112x64_F16E5M2E4M3_RS_TN uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %35, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27}," - "{%28, %29, %30, %31}," - " %32," - " %33, %34," - " p, %36, %37;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -45329,34 +15624,37 @@ struct GMMA_64x112x64_F16E5M2E4M3_RS_TN "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), 
"+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x64 TN F32+=E5M2*E4M3 +// SPARSE GMMA 64x192x64 TN F32+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x112x64_F32E5M2E4M3_SS_TN +struct GMMA_64x192x64_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[56]; + using CRegisters = float[96]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -45375,26 +15673,42 @@ struct GMMA_64x112x64_F32E5M2E4M3_SS_TN float & d44, float & d45, float & d46, float & d47, float & d48, float & d49, float & d50, float & d51, float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & 
d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %60, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e5m2.e4m3 " + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " %58, %59," - " p, %61, %62;\n" + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -45409,34 +15723,42 @@ struct GMMA_64x112x64_F32E5M2E4M3_SS_TN "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), 
"+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x64 TN F32+=E5M2*E4M3 +// SPARSE GMMA 64x192x64 TN F32+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x112x64_F32E5M2E4M3_RS_TN +struct GMMA_64x192x64_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[56]; + using CRegisters = float[96]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -45455,26 +15777,42 @@ struct GMMA_64x112x64_F32E5M2E4M3_RS_TN float & d44, float & d45, float & d46, float & d47, float & d48, float & d49, float & d50, float & d51, float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & 
d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %63, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e5m2.e4m3 " + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " %61, %62," - " p, %64, %65;\n" + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -45489,33 +15827,42 @@ struct GMMA_64x112x64_F32E5M2E4M3_RS_TN "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), 
"+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x128x64 TN F16+=E5M2*E4M3 +// SPARSE GMMA 64x256x64 TN F16+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x128x64_F16E5M2E4M3_SS_TN +struct GMMA_64x256x64_F16E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -45528,23 +15875,36 @@ struct GMMA_64x128x64_F16E5M2E4M3_SS_TN uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + 
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %36, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - " %32," - " %33," - " %34, %35," - " p, %37, %38;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -45553,32 +15913,40 @@ struct GMMA_64x128x64_F16E5M2E4M3_SS_TN "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x128x64 TN F16+=E5M2*E4M3 +// SPARSE GMMA 64x256x64 TN F16+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x128x64_F16E5M2E4M3_RS_TN +struct GMMA_64x256x64_F16E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -45591,23 +15959,36 @@ struct GMMA_64x128x64_F16E5M2E4M3_RS_TN uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" 
".reg .pred p;\n" - "setp.ne.b32 p, %39, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - "{%32, %33, %34, %35}," - " %36," - " %37, %38," - " p, %40, %41;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -45616,61 +15997,86 @@ struct GMMA_64x128x64_F16E5M2E4M3_RS_TN "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x128x64 TN F32+=E5M2*E4M3 +// SPARSE GMMA 64x256x64 TN F32+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x128x64_F32E5M2E4M3_SS_TN +struct GMMA_64x256x64_F32E4M3E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[64]; + using CRegisters = float[128]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & 
d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %68, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e5m2.e4m3 " + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " 
" %16, %17, %18, %19, %20, %21, %22, %23, " @@ -45678,82 +16084,123 @@ struct GMMA_64x128x64_F32E5M2E4M3_SS_TN " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - " %64," - " %65," - " %66, %67," - " p, %69, %70;\n" + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + 
"+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x128x64 TN F32+=E5M2*E4M3 +// SPARSE GMMA 64x256x64 TN F32+=E4M3*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x128x64_F32E5M2E4M3_RS_TN 
+struct GMMA_64x256x64_F32E4M3E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[64]; + using CRegisters = float[128]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, 
float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %71, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e5m2.e4m3 " + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e4m3.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -45761,846 +16208,555 @@ struct GMMA_64x128x64_F32E5M2E4M3_RS_TN " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - "{%64, %65, %66, %67}," - " %68," - " %69, 
%70," - " p, %72, %73;\n" + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + 
"+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x64 TN F16+=E5M2*E4M3 +// SPARSE GMMA 64x8x64 TN F16+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x144x64_F16E5M2E4M3_SS_TN +struct GMMA_64x8x64_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using 
CRegisters = uint32_t[36]; + using CRegisters = uint32_t[2]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d0, uint32_t & d1, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %40, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35}," - " %36," - " %37," - " %38, %39," - " p, %41, %42;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e4m3.e5m2 " + "{%0, %1}," + " %2," + " %3," + " %4, %5," + " p, %7, %8;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "+r"(d0), "+r"(d1) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x64 TN F16+=E5M2*E4M3 +// SPARSE GMMA 64x8x64 TN F16+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x144x64_F16E5M2E4M3_RS_TN +struct GMMA_64x8x64_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[36]; + using CRegisters = uint32_t[2]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d0, uint32_t & d1, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %43, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35}," - "{%36, %37, %38, %39}," - " %40," - " %41, %42," - " p, %44, %45;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e4m3.e5m2 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " %7, %8," + " p, %10, %11;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x64 TN F32+=E5M2*E4M3 +// SPARSE GMMA 64x8x64 TN F32+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x144x64_F32E5M2E4M3_SS_TN 
+struct GMMA_64x8x64_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[72]; + using CRegisters = float[4]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, + float & d0, float & d1, float & d2, float & d3, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %76, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, 
" - " %64, %65, %66, %67, %68, %69, %70, %71}," - " %72," - " %73," - " %74, %75," - " p, %77, %78;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x64 TN F32+=E5M2*E4M3 +// SPARSE GMMA 64x8x64 TN F32+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x144x64_F32E5M2E4M3_RS_TN 
+struct GMMA_64x8x64_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[72]; + using CRegisters = float[4]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, + float & d0, float & d1, float & d2, float & d3, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %79, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, 
%39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - "{%72, %73, %74, %75}," - " %76," - " %77, %78," - " p, %80, %81;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x64 TN 
F16+=E5M2*E4M3 +// SPARSE GMMA 64x16x64 TN F16+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x160x64_F16E5M2E4M3_SS_TN +struct GMMA_64x16x64_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %44, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " %42, %43," - " p, %45, %46;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," 
+ " %6, %7," + " p, %9, %10;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x64 TN F16+=E5M2*E4M3 +// SPARSE GMMA 64x16x64 TN F16+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x160x64_F16E5M2E4M3_RS_TN +struct GMMA_64x16x64_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t 
& d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %47, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " %45, %46," - " p, %48, %49;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x64 TN F32+=E5M2*E4M3 +// SPARSE GMMA 64x16x64 TN F32+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x160x64_F32E5M2E4M3_SS_TN +struct GMMA_64x16x64_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[80]; + using CRegisters = float[8]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, 
float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %84, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - " %80," - " %81," - " %82, %83," - " p, %85, %86;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - 
"+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x64 TN F32+=E5M2*E4M3 +// SPARSE GMMA 64x16x64 TN F32+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x160x64_F32E5M2E4M3_RS_TN +struct GMMA_64x16x64_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[80]; + using CRegisters = float[8]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, 
- float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %87, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - "{%80, %81, %82, %83}," - " %84," - " %85, %86," - " p, %88, %89;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), 
"+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x64 TN F16+=E5M2*E4M3 +// SPARSE GMMA 64x32x64 TN F16+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x176x64_F16E5M2E4M3_SS_TN +struct GMMA_64x32x64_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[44]; + using CRegisters = uint32_t[8]; 
CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %48, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43}," - " %44," - " %45," - " %46, %47," - " p, %49, %50;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), 
"+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x64 TN F16+=E5M2*E4M3 +// SPARSE GMMA 64x32x64 TN F16+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x176x64_F16E5M2E4M3_RS_TN +struct GMMA_64x32x64_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[44]; + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, 
uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %51, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43}," - "{%44, %45, %46, %47}," - " %48," - " %49, %50," - " p, %52, %53;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), 
"r"(a3), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x64 TN F32+=E5M2*E4M3 +// SPARSE GMMA 64x32x64 TN F32+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x176x64_F32E5M2E4M3_SS_TN +struct GMMA_64x32x64_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[88]; + using CRegisters = float[16]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -46609,98 +16765,52 @@ struct GMMA_64x176x64_F32E5M2E4M3_SS_TN float & d04, float & d05, float & d06, float & d07, float & d08, float & d09, float & d10, float & d11, float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & 
d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %92, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - " %88," - " %89," - " %90, %91," - " p, %93, %94;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), 
"+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x64 TN F32+=E5M2*E4M3 +// SPARSE GMMA 64x32x64 TN F32+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x176x64_F32E5M2E4M3_RS_TN +struct GMMA_64x32x64_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[88]; + using CRegisters = float[16]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -46709,97 +16819,52 @@ struct GMMA_64x176x64_F32E5M2E4M3_RS_TN float & d04, float & d05, float & d06, float & d07, float & d08, float & d09, float & d10, float & d11, float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float 
& d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %95, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - "{%88, %89, %90, %91}," - " %92," - " %93, %94," - " p, %96, %97;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, 
%25;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x192x64 TN F16+=E5M2*E4M3 +// SPARSE GMMA 64x64x64 TN F16+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x192x64_F16E5M2E4M3_SS_TN +struct GMMA_64x64x64_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using 
ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -46808,71 +16873,52 @@ struct GMMA_64x192x64_F16E5M2E4M3_SS_TN uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %52, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e5m2.e4m3 " + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e4m3.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " %50, %51," - " p, %53, %54;\n" + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), 
"+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x192x64 TN F16+=E5M2*E4M3 +// SPARSE GMMA 64x64x64 TN F16+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x192x64_F16E5M2E4M3_RS_TN +struct GMMA_64x64x64_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -46881,124 +16927,82 @@ struct GMMA_64x192x64_F16E5M2E4M3_RS_TN uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, 
- uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %55, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " %53, %54," - " p, %56, %57;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x192x64 TN F32+=E5M2*E4M3 +// SPARSE GMMA 64x64x64 TN F32+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x192x64_F32E5M2E4M3_SS_TN +struct GMMA_64x64x64_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[96]; + using CRegisters = float[32]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - 
float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, - float & d88, float & d89, float & d90, float & d91, - float & d92, float & d93, float & d94, float & d95, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %100, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - " %96," - " %97," - " %98, %99," - " p, %101, %102;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), 
"+f"(d05), "+f"(d06), "+f"(d07), @@ -47007,48 +17011,32 @@ struct GMMA_64x192x64_F32E5M2E4M3_SS_TN "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), - "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), - "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x192x64 TN F32+=E5M2*E4M3 +// SPARSE GMMA 64x64x64 TN F32+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x192x64_F32E5M2E4M3_RS_TN +struct GMMA_64x64x64_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters 
= uint64_t[1]; - using CRegisters = float[96]; + using CRegisters = float[32]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -47061,47 +17049,24 @@ struct GMMA_64x192x64_F32E5M2E4M3_RS_TN float & d20, float & d21, float & d22, float & d23, float & d24, float & d25, float & d26, float & d27, float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, - float & d88, float & d89, float & d90, float & d91, - float & d92, float & d93, float & d94, float & d95, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %103, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " 
%72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - "{%96, %97, %98, %99}," - " %100," - " %101, %102," - " p, %104, %105;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -47110,49 +17075,32 @@ struct GMMA_64x192x64_F32E5M2E4M3_RS_TN "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), - "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), - "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x64 TN F16+=E5M2*E4M3 +// SPARSE GMMA 64x96x64 TN F16+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x208x64_F16E5M2E4M3_SS_TN +struct GMMA_64x96x64_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[52]; + using CRegisters = uint32_t[24]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -47163,74 +17111,55 @@ struct GMMA_64x208x64_F16E5M2E4M3_SS_TN uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %56, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, 
%18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51}," - " %52," - " %53," - " %54, %55," - " p, %57, %58;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x64 TN F16+=E5M2*E4M3 +// SPARSE GMMA 64x96x64 TN F16+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct 
GMMA_64x208x64_F16E5M2E4M3_RS_TN +struct GMMA_64x96x64_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[52]; + using CRegisters = uint32_t[24]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -47241,294 +17170,203 @@ struct GMMA_64x208x64_F16E5M2E4M3_RS_TN uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %59, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51}," - "{%52, %53, %54, %55}," - " %56," - " %57, %58," - " p, %60, %61;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + 
"{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x64 TN F32+=E5M2*E4M3 +// SPARSE GMMA 64x96x64 TN F32+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x208x64_F32E5M2E4M3_SS_TN +struct GMMA_64x96x64_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[104]; + using CRegisters = float[48]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & 
d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float 
& d43, + float & d44, float & d45, float & d46, float & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %108, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - " %104," - " %105," - " %106, %107," - " p, %109, %110;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), 
- "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x64 TN F32+=E5M2*E4M3 +// SPARSE GMMA 64x96x64 TN F32+=E4M3*E5M2 template < GMMA::ScaleIn 
scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x208x64_F32E5M2E4M3_RS_TN +struct GMMA_64x96x64_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[104]; + using CRegisters = float[48]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & 
d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %111, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e5m2.e4m3 " + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e4m3.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - "{%104, %105, %106, %107}," - " %108," - " %109, %110," - " p, %112, %113;\n" + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" "}\n" - : "+f"(d000), "+f"(d001), 
"+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x64 TN F16+=E5M2*E4M3 +// SPARSE GMMA 64x128x64 TN F16+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x224x64_F16E5M2E4M3_SS_TN +struct GMMA_64x128x64_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -47541,32 +17379,24 @@ struct GMMA_64x224x64_F16E5M2E4M3_SS_TN uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %60, 0;\n" - 
"wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " %58, %59," - " p, %61, %62;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -47575,40 +17405,32 @@ struct GMMA_64x224x64_F16E5M2E4M3_SS_TN "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE 
GMMA 64x224x64 TN F16+=E5M2*E4M3 +// SPARSE GMMA 64x128x64 TN F16+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x224x64_F16E5M2E4M3_RS_TN +struct GMMA_64x128x64_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -47621,32 +17443,24 @@ struct GMMA_64x224x64_F16E5M2E4M3_RS_TN uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %63, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " %61, %62," - " p, %64, %65;\n" + "setp.ne.b32 p, %39, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -47655,81 +17469,62 @@ struct GMMA_64x224x64_F16E5M2E4M3_RS_TN "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x64 TN F32+=E5M2*E4M3 +// SPARSE GMMA 64x128x64 TN F32+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x224x64_F32E5M2E4M3_SS_TN +struct GMMA_64x128x64_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = 
uint64_t[1]; - using CRegisters = float[112]; + using CRegisters = float[64]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, 
+ float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %116, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e5m2.e4m3 " + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e4m3.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -47737,114 +17532,83 @@ struct GMMA_64x224x64_F32E5M2E4M3_SS_TN " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - " %112," - " %113," - " %114, %115," - " p, %117, %118;\n" + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), 
"+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), 
"+f"(d61), "+f"(d62), "+f"(d63) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x64 TN F32+=E5M2*E4M3 +// SPARSE GMMA 64x128x64 TN F32+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x224x64_F32E5M2E4M3_RS_TN +struct GMMA_64x128x64_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[112]; + using CRegisters = float[64]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float 
& d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %119, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e5m2.e4m3 " + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e4m3.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -47852,73 +17616,53 @@ struct GMMA_64x224x64_F32E5M2E4M3_RS_TN " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - "{%112, %113, %114, %115}," - " %116," - " %117, %118," - " p, %120, %121;\n" + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - 
"+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x64 TN F16+=E5M2*E4M3 +// SPARSE 
GMMA 64x192x64 TN F16+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x240x64_F16E5M2E4M3_SS_TN +struct GMMA_64x192x64_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[60]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -47935,30 +17679,26 @@ struct GMMA_64x240x64_F16E5M2E4M3_SS_TN uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %64, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e5m2.e4m3 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59}," - " %60," - " %61," - " %62, %63," - " p, %65, %66;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, 
%42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -47971,37 +17711,32 @@ struct GMMA_64x240x64_F16E5M2E4M3_SS_TN "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x64 TN F16+=E5M2*E4M3 +// SPARSE GMMA 64x192x64 TN F16+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x240x64_F16E5M2E4M3_RS_TN +struct GMMA_64x192x64_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[60]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -48018,30 +17753,26 @@ struct GMMA_64x240x64_F16E5M2E4M3_RS_TN uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, 
uint32_t & d43, uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %67, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e5m2.e4m3 " + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e4m3.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59}," - "{%60, %61, %62, %63}," - " %64," - " %65, %66," - " p, %68, %69;\n" + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -48054,80 +17785,70 @@ struct GMMA_64x240x64_F16E5M2E4M3_RS_TN "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E4M3_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x64 TN F32+=E5M2*E4M3 +// SPARSE GMMA 64x192x64 TN F32+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x240x64_F32E5M2E4M3_SS_TN +struct GMMA_64x192x64_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[120]; + using CRegisters = float[96]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & 
d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, uint32_t const& e, 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %124, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e5m2.e4m3 " + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e4m3.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -48139,115 +17860,99 @@ struct GMMA_64x240x64_F32E5M2E4M3_SS_TN " %64, %65, %66, %67, %68, %69, %70, %71, " " %72, %73, %74, %75, %76, %77, %78, %79, " " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - " %120," - " %121," - " %122, %123," - " p, %125, %126;\n" + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - 
"+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x64 TN F32+=E5M2*E4M3 +// SPARSE GMMA 64x192x64 TN F32+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x240x64_F32E5M2E4M3_RS_TN +struct GMMA_64x192x64_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[120]; + using CRegisters = float[96]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float 
& d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float 
& d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %127, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e5m2.e4m3 " + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e4m3.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -48259,65 +17964,55 @@ struct GMMA_64x240x64_F32E5M2E4M3_RS_TN " %64, %65, %66, %67, %68, %69, %70, %71, " " %72, %73, %74, %75, %76, %77, %78, %79, " " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - "{%120, %121, %122, %123}," - " %124," - " %125, %126," - " p, %128, %129;\n" + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), 
"+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x256x64 TN F16+=E5M2*E4M3 +// SPARSE GMMA 64x256x64 TN F16+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x256x64_F16E5M2E4M3_SS_TN +struct GMMA_64x256x64_F16E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -48348,11 +18043,12 @@ struct GMMA_64x256x64_F16E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %68, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e5m2.e4m3 " + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e4m3.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -48387,20 +18083,20 @@ struct GMMA_64x256x64_F16E5M2E4M3_SS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x256x64 TN F16+=E5M2*E4M3 +// SPARSE GMMA 64x256x64 
TN F16+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x256x64_F16E5M2E4M3_RS_TN +struct GMMA_64x256x64_F16E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -48431,11 +18127,12 @@ struct GMMA_64x256x64_F16E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %71, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e5m2.e4m3 " + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e4m3.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -48470,20 +18167,20 @@ struct GMMA_64x256x64_F16E5M2E4M3_RS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x256x64 TN F32+=E5M2*E4M3 +// SPARSE GMMA 64x256x64 TN F32+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x256x64_F32E5M2E4M3_SS_TN +struct GMMA_64x256x64_F32E4M3E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -48530,11 +18227,12 @@ struct GMMA_64x256x64_F32E5M2E4M3_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, 
desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %132, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e5m2.e4m3 " + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e4m3.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -48593,20 +18291,20 @@ struct GMMA_64x256x64_F32E5M2E4M3_SS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x256x64 TN F32+=E5M2*E4M3 +// SPARSE GMMA 64x256x64 TN F32+=E4M3*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x256x64_F32E5M2E4M3_RS_TN +struct GMMA_64x256x64_F32E4M3E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -48653,11 +18351,12 @@ struct GMMA_64x256x64_F32E5M2E4M3_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %135, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e5m2.e4m3 " + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e4m3.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -48716,20 +18415,20 @@ struct GMMA_64x256x64_F32E5M2E4M3_RS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x8x64 TN F16+=E5M2*E5M2 +// SPARSE GMMA 64x8x64 TN F16+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x8x64_F16E5M2E5M2_SS_TN +struct GMMA_64x8x64_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -48745,11 +18444,12 @@ struct GMMA_64x8x64_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %6, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e5m2.e5m2 " + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e5m2.e4m3 " "{%0, %1}," " %2," " %3," @@ -48762,20 +18462,20 @@ struct GMMA_64x8x64_F16E5M2E5M2_SS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x8x64 TN F16+=E5M2*E5M2 +// SPARSE GMMA 64x8x64 TN F16+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct 
GMMA_64x8x64_F16E5M2E5M2_RS_TN +struct GMMA_64x8x64_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -48791,11 +18491,12 @@ struct GMMA_64x8x64_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %9, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e5m2.e5m2 " + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e5m2.e4m3 " "{%0, %1}," "{%2, %3, %4, %5}," " %6," @@ -48808,20 +18509,20 @@ struct GMMA_64x8x64_F16E5M2E5M2_RS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x8x64 TN F32+=E5M2*E5M2 +// SPARSE GMMA 64x8x64 TN F32+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x8x64_F32E5M2E5M2_SS_TN +struct GMMA_64x8x64_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -48837,11 +18538,12 @@ struct GMMA_64x8x64_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %8, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e5m2.e5m2 " + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e5m2.e4m3 " "{%0, %1, %2, %3}," " %4," " %5," @@ -48854,20 +18556,20 @@ struct 
GMMA_64x8x64_F32E5M2E5M2_SS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x8x64 TN F32+=E5M2*E5M2 +// SPARSE GMMA 64x8x64 TN F32+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x8x64_F32E5M2E5M2_RS_TN +struct GMMA_64x8x64_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -48883,11 +18585,12 @@ struct GMMA_64x8x64_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %11, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e5m2.e5m2 " + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e5m2.e4m3 " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," " %8," @@ -48900,20 +18603,20 @@ struct GMMA_64x8x64_F32E5M2E5M2_RS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x16x64 TN F16+=E5M2*E5M2 +// SPARSE GMMA 64x16x64 TN F16+=E5M2*E4M3 template < 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x16x64_F16E5M2E5M2_SS_TN +struct GMMA_64x16x64_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -48929,11 +18632,12 @@ struct GMMA_64x16x64_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %8, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e5m2.e5m2 " + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e5m2.e4m3 " "{%0, %1, %2, %3}," " %4," " %5," @@ -48946,20 +18650,20 @@ struct GMMA_64x16x64_F16E5M2E5M2_SS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x16x64 TN F16+=E5M2*E5M2 +// SPARSE GMMA 64x16x64 TN F16+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x16x64_F16E5M2E5M2_RS_TN +struct GMMA_64x16x64_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -48975,11 +18679,12 @@ struct GMMA_64x16x64_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %11, 0;\n" - 
"wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e5m2.e5m2 " + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e5m2.e4m3 " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," " %8," @@ -48992,20 +18697,20 @@ struct GMMA_64x16x64_F16E5M2E5M2_RS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x16x64 TN F32+=E5M2*E5M2 +// SPARSE GMMA 64x16x64 TN F32+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x16x64_F32E5M2E5M2_SS_TN +struct GMMA_64x16x64_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -49022,11 +18727,12 @@ struct GMMA_64x16x64_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %12, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e5m2.e5m2 " + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7}," " %8," " %9," @@ -49040,20 +18746,20 @@ struct GMMA_64x16x64_F32E5M2E5M2_SS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E4M3_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x16x64 TN F32+=E5M2*E5M2 +// SPARSE GMMA 64x16x64 TN F32+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x16x64_F32E5M2E5M2_RS_TN +struct GMMA_64x16x64_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -49070,11 +18776,12 @@ struct GMMA_64x16x64_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %15, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e5m2.e5m2 " + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7}," "{%8, %9, %10, %11}," " %12," @@ -49088,20 +18795,20 @@ struct GMMA_64x16x64_F32E5M2E5M2_RS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x32x64 TN F16+=E5M2*E5M2 +// SPARSE GMMA 64x32x64 TN F16+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x32x64_F16E5M2E5M2_SS_TN +struct GMMA_64x32x64_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -49118,11 +18825,12 @@ struct GMMA_64x32x64_F16E5M2E5M2_SS_TN GMMA::ScaleOut 
const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %12, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e5m2.e5m2 " + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7}," " %8," " %9," @@ -49136,20 +18844,20 @@ struct GMMA_64x32x64_F16E5M2E5M2_SS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x32x64 TN F16+=E5M2*E5M2 +// SPARSE GMMA 64x32x64 TN F16+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x32x64_F16E5M2E5M2_RS_TN +struct GMMA_64x32x64_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -49166,11 +18874,12 @@ struct GMMA_64x32x64_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %15, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e5m2.e5m2 " + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7}," "{%8, %9, %10, %11}," " %12," @@ -49184,20 +18893,20 @@ struct GMMA_64x32x64_F16E5M2E5M2_RS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x32x64 TN F32+=E5M2*E5M2 +// SPARSE GMMA 64x32x64 TN F32+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x32x64_F32E5M2E5M2_SS_TN +struct GMMA_64x32x64_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -49216,11 +18925,12 @@ struct GMMA_64x32x64_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %20, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e5m2.e5m2 " + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," " %16," @@ -49237,20 +18947,20 @@ struct GMMA_64x32x64_F32E5M2E5M2_SS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x32x64 TN F32+=E5M2*E5M2 +// SPARSE GMMA 64x32x64 TN F32+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x32x64_F32E5M2E5M2_RS_TN +struct GMMA_64x32x64_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -49269,11 +18979,12 @@ struct GMMA_64x32x64_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %23, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e5m2.e5m2 " + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," "{%16, %17, %18, %19}," @@ -49290,246 +19001,20 @@ struct GMMA_64x32x64_F32E5M2E5M2_RS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x64 TN F16+=E5M2*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x48x64_F16E5M2E5M2_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[12]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %16, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11}," - " %12," - " %13," - " %14, %15," - " p, %17, %18;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x64 TN F16+=E5M2*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x48x64_F16E5M2E5M2_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[12]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %19, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11}," - "{%12, %13, %14, %15}," - " %16," - 
" %17, %18," - " p, %20, %21;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x64 TN F32+=E5M2*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x48x64_F32E5M2E5M2_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[24]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %28, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " %26, %27," - " p, %29, %30;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), 
"+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) - : "l"(desc_a), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x48x64 TN F32+=E5M2*E5M2 -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x48x64_F32E5M2E5M2_RS_TN -{ - using DRegisters = void; - using ARegisters = uint32_t[4]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[24]; - - CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, - uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - uint32_t const& e, - GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) - { -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %31, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " %29, %30," - " p, %32, %33;\n" - "}\n" 
- : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), - "l"(desc_b), - "r"(e), "n"(int32_t(spsel)), - "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x64x64 TN F16+=E5M2*E5M2 +// SPARSE GMMA 64x64x64 TN F16+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x64x64_F16E5M2E5M2_SS_TN +struct GMMA_64x64x64_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -49548,11 +19033,12 @@ struct GMMA_64x64x64_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %20, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e5m2.e5m2 " + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," " %16," @@ -49569,20 +19055,20 @@ struct GMMA_64x64x64_F16E5M2E5M2_SS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E5M2_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x64x64 TN F16+=E5M2*E5M2 +// SPARSE GMMA 64x64x64 TN F16+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x64x64_F16E5M2E5M2_RS_TN +struct GMMA_64x64x64_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -49601,11 +19087,12 @@ struct GMMA_64x64x64_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %23, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e5m2.e5m2 " + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15}," "{%16, %17, %18, %19}," @@ -49622,20 +19109,20 @@ struct GMMA_64x64x64_F16E5M2E5M2_RS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x64x64 TN F32+=E5M2*E5M2 +// SPARSE GMMA 64x64x64 TN F32+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x64x64_F32E5M2E5M2_SS_TN 
+struct GMMA_64x64x64_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; @@ -49658,11 +19145,12 @@ struct GMMA_64x64x64_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %36, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e5m2.e5m2 " + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -49685,20 +19173,20 @@ struct GMMA_64x64x64_F32E5M2E5M2_SS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x64x64 TN F32+=E5M2*E5M2 +// SPARSE GMMA 64x64x64 TN F32+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x64x64_F32E5M2E5M2_RS_TN +struct GMMA_64x64x64_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; @@ -49721,11 +19209,12 @@ struct GMMA_64x64x64_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" "setp.ne.b32 p, %39, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e5m2.e5m2 " + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e5m2.e4m3 " "{%0, %1, %2, %3, 
%4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -49748,27 +19237,26 @@ struct GMMA_64x64x64_F32E5M2E5M2_RS_TN "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x64 TN F16+=E5M2*E5M2 +// SPARSE GMMA 64x96x64 TN F16+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x80x64_F16E5M2E5M2_SS_TN +struct GMMA_64x96x64_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[20]; + using CRegisters = uint32_t[24]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -49778,55 +19266,56 @@ struct GMMA_64x80x64_F16E5M2E5M2_SS_TN uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %24, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e5m2.e5m2 " + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e5m2.e4m3 
" "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19}," - " %20," - " %21," - " %22, %23," - " p, %25, %26;\n" + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x64 TN F16+=E5M2*E5M2 +// SPARSE GMMA 64x96x64 TN F16+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x80x64_F16E5M2E5M2_RS_TN +struct GMMA_64x96x64_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[20]; + using CRegisters = uint32_t[24]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -49836,55 +19325,56 @@ struct GMMA_64x80x64_F16E5M2E5M2_RS_TN uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t & d16, 
uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %27, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e5m2.e5m2 " + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19}," - "{%20, %21, %22, %23}," - " %24," - " %25, %26," - " p, %28, %29;\n" + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x64 TN F32+=E5M2*E5M2 +// SPARSE GMMA 64x96x64 TN F32+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x80x64_F32E5M2E5M2_SS_TN 
+struct GMMA_64x96x64_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[40]; + using CRegisters = float[48]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -49899,24 +19389,28 @@ struct GMMA_64x80x64_F32E5M2E5M2_SS_TN float & d28, float & d29, float & d30, float & d31, float & d32, float & d33, float & d34, float & d35, float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %44, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e5m2.e5m2 " + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " %42, %43," - " p, %45, %46;\n" + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -49927,34 +19421,34 @@ struct GMMA_64x80x64_F32E5M2E5M2_SS_TN "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x80x64 TN F32+=E5M2*E5M2 +// SPARSE GMMA 64x96x64 TN F32+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x80x64_F32E5M2E5M2_RS_TN +struct GMMA_64x96x64_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[40]; + using CRegisters = float[48]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -49969,24 +19463,28 @@ struct GMMA_64x80x64_F32E5M2E5M2_RS_TN float & d28, float & d29, float & d30, float & d31, float & d32, float & d33, float & d34, float & d35, float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %47, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " 
%44," - " %45, %46," - " p, %48, %49;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -49997,33 +19495,34 @@ struct GMMA_64x80x64_F32E5M2E5M2_RS_TN "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x96x64 TN F16+=E5M2*E5M2 +// SPARSE GMMA 64x128x64 TN F16+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x96x64_F16E5M2E5M2_SS_TN +struct GMMA_64x128x64_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; + using CRegisters = 
uint32_t[32]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -50034,54 +19533,60 @@ struct GMMA_64x96x64_F16E5M2E5M2_SS_TN uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %28, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e5m2.e5m2 " + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - " %24," - " %25," - " %26, %27," - " p, %29, %30;\n" + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E4M3_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x96x64 TN F16+=E5M2*E5M2 +// SPARSE GMMA 64x128x64 TN F16+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x96x64_F16E5M2E5M2_RS_TN +struct GMMA_64x128x64_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[24]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -50092,54 +19597,60 @@ struct GMMA_64x96x64_F16E5M2E5M2_RS_TN uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %31, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e5m2.e5m2 " + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23}," - "{%24, %25, %26, %27}," - " %28," - " %29, %30," - " p, %32, %33;\n" + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), 
"+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x96x64 TN F32+=E5M2*E5M2 +// SPARSE GMMA 64x128x64 TN F32+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x96x64_F32E5M2E5M2_SS_TN +struct GMMA_64x128x64_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[48]; + using CRegisters = float[64]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -50156,25 +19667,32 @@ struct GMMA_64x96x64_F32E5M2E5M2_SS_TN float & d36, float & d37, float & d38, float & d39, float & d40, float & d41, float & d42, float & d43, float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %52, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " %50, %51," - " p, %53, %54;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -50187,32 +19705,36 @@ struct GMMA_64x96x64_F32E5M2E5M2_SS_TN "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); 
#endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x96x64 TN F32+=E5M2*E5M2 +// SPARSE GMMA 64x128x64 TN F32+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x96x64_F32E5M2E5M2_RS_TN +struct GMMA_64x128x64_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[48]; + using CRegisters = float[64]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -50229,25 +19751,32 @@ struct GMMA_64x96x64_F32E5M2E5M2_RS_TN float & d36, float & d37, float & d38, float & d39, float & d40, float & d41, float & d42, float & d43, float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %55, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e5m2.e5m2 " + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " %53, %54," - " p, %56, %57;\n" + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, 
%58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -50260,33 +19789,36 @@ struct GMMA_64x96x64_F32E5M2E5M2_RS_TN "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x64 TN F16+=E5M2*E5M2 +// SPARSE GMMA 64x192x64 TN F16+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x112x64_F16E5M2E5M2_SS_TN +struct GMMA_64x192x64_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[28]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -50298,23 +19830,31 @@ struct GMMA_64x112x64_F16E5M2E5M2_SS_TN uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, 
uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %32, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e5m2.e5m2 " + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27}," - " %28," - " %29," - " %30, %31," - " p, %33, %34;\n" + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -50322,34 +19862,37 @@ struct GMMA_64x112x64_F16E5M2E5M2_SS_TN "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x64 TN F16+=E5M2*E5M2 +// SPARSE GMMA 64x192x64 TN F16+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x112x64_F16E5M2E5M2_RS_TN +struct GMMA_64x192x64_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[28]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -50361,23 +19904,31 @@ struct GMMA_64x112x64_F16E5M2E5M2_RS_TN uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %35, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, 
%6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27}," - "{%28, %29, %30, %31}," - " %32," - " %33, %34," - " p, %36, %37;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -50385,34 +19936,37 @@ struct GMMA_64x112x64_F16E5M2E5M2_RS_TN "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x64 TN F32+=E5M2*E5M2 +// SPARSE GMMA 64x192x64 TN F32+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x112x64_F32E5M2E5M2_SS_TN +struct GMMA_64x192x64_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[56]; + using CRegisters = float[96]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -50431,26 +19985,42 @@ struct GMMA_64x112x64_F32E5M2E5M2_SS_TN float & d44, float & d45, float & d46, float & d47, float & d48, float & d49, float & d50, float & d51, float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %60, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e5m2.e5m2 " + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " %58, %59," - " p, %61, %62;\n" + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, 
%57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -50465,34 +20035,42 @@ struct GMMA_64x112x64_F32E5M2E5M2_SS_TN "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x112x64 TN F32+=E5M2*E5M2 +// SPARSE GMMA 64x192x64 TN F32+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x112x64_F32E5M2E5M2_RS_TN +struct 
GMMA_64x192x64_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[56]; + using CRegisters = float[96]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -50511,26 +20089,42 @@ struct GMMA_64x112x64_F32E5M2E5M2_RS_TN float & d44, float & d45, float & d46, float & d47, float & d48, float & d49, float & d50, float & d51, float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %63, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e5m2.e5m2 " + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " %61, %62," - " p, %64, %65;\n" + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, 
%71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -50545,33 +20139,42 @@ struct GMMA_64x112x64_F32E5M2E5M2_RS_TN "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x128x64 TN F16+=E5M2*E5M2 +// SPARSE GMMA 64x256x64 TN F16+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x128x64_F16E5M2E5M2_SS_TN +struct GMMA_64x256x64_F16E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; 
using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -50584,23 +20187,36 @@ struct GMMA_64x128x64_F16E5M2E5M2_SS_TN uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %36, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - " %32," - " %33," - " %34, %35," - " p, %37, %38;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, 
%67," + " p, %69, %70;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -50609,32 +20225,40 @@ struct GMMA_64x128x64_F16E5M2E5M2_SS_TN "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x128x64 TN F16+=E5M2*E5M2 +// SPARSE GMMA 64x256x64 TN F16+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x128x64_F16E5M2E5M2_RS_TN +struct GMMA_64x256x64_F16E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[32]; + using CRegisters = uint32_t[64]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -50647,23 +20271,36 @@ struct GMMA_64x128x64_F16E5M2E5M2_RS_TN uint32_t 
& d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %39, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31}," - "{%32, %33, %34, %35}," - " %36," - " %37, %38," - " p, %40, %41;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -50672,61 +20309,86 @@ struct GMMA_64x128x64_F16E5M2E5M2_RS_TN "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), 
"+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x128x64 TN F32+=E5M2*E5M2 +// SPARSE GMMA 64x256x64 TN F32+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x128x64_F32E5M2E5M2_SS_TN +struct GMMA_64x256x64_F32E5M2E4M3_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[64]; + using CRegisters = float[128]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float 
& d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, 
float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %68, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e5m2.e5m2 " + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -50734,82 +20396,123 @@ struct GMMA_64x128x64_F32E5M2E5M2_SS_TN " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - " %64," - " %65," - " %66, %67," - " p, %69, %70;\n" + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), 
"+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), 
"+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x128x64 TN F32+=E5M2*E5M2 +// SPARSE GMMA 64x256x64 TN F32+=E5M2*E4M3 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x128x64_F32E5M2E5M2_RS_TN +struct GMMA_64x256x64_F32E5M2E4M3_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[64]; + using CRegisters = float[128]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, 
float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, 
+ float & d124, float & d125, float & d126, float & d127, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %71, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e5m2.e5m2 " + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e5m2.e4m3 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -50817,846 +20520,555 @@ struct GMMA_64x128x64_F32E5M2E5M2_RS_TN " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63}," - "{%64, %65, %66, %67}," - " %68," - " %69, %70," - " p, %72, %73;\n" + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), 
"+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x64 TN F16+=E5M2*E5M2 +// SPARSE GMMA 64x8x64 TN F16+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x144x64_F16E5M2E5M2_SS_TN +struct GMMA_64x8x64_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[36]; + using CRegisters = uint32_t[2]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d0, uint32_t & d1, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %40, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, 
%7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35}," - " %36," - " %37," - " %38, %39," - " p, %41, %42;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e5m2.e5m2 " + "{%0, %1}," + " %2," + " %3," + " %4, %5," + " p, %7, %8;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "+r"(d0), "+r"(d1) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x64 TN F16+=E5M2*E5M2 +// SPARSE GMMA 64x8x64 TN F16+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x144x64_F16E5M2E5M2_RS_TN +struct GMMA_64x8x64_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[36]; + using CRegisters = uint32_t[2]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, 
uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d0, uint32_t & d1, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %43, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35}," - "{%36, %37, %38, %39}," - " %40," - " %41, %42," - " p, %44, %45;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e5m2.e5m2 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " %7, %8," + " p, %10, %11;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d0), "+r"(d1) + : 
"r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x64 TN F32+=E5M2*E5M2 +// SPARSE GMMA 64x8x64 TN F32+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x144x64_F32E5M2E5M2_SS_TN +struct GMMA_64x8x64_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[72]; + using CRegisters = float[4]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & 
d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, + float & d0, float & d1, float & d2, float & d3, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %76, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - " %72," - " %73," - " %74, %75," - " p, %77, %78;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "+f"(d0), "+f"(d1), 
"+f"(d2), "+f"(d3) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x144x64 TN F32+=E5M2*E5M2 +// SPARSE GMMA 64x8x64 TN F32+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x144x64_F32E5M2E5M2_RS_TN +struct GMMA_64x8x64_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[72]; + using CRegisters = float[4]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & 
d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, + float & d0, float & d1, float & d2, float & d3, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %79, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71}," - "{%72, %73, %74, %75}," - " %76," - " %77, %78," - " p, %80, %81;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), 
"+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x64 TN F16+=E5M2*E5M2 +// SPARSE GMMA 64x16x64 TN F16+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x160x64_F16E5M2E5M2_SS_TN +struct GMMA_64x16x64_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & 
d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %44, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39}," - " %40," - " %41," - " %42, %43," - " p, %45, %46;\n" - "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x64 TN F16+=E5M2*E5M2 +// SPARSE GMMA 64x16x64 TN F16+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x160x64_F16E5M2E5M2_RS_TN +struct GMMA_64x16x64_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[40]; + using CRegisters = uint32_t[4]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %47, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, 
%34, %35, %36, %37, %38, %39}," - "{%40, %41, %42, %43}," - " %44," - " %45, %46," - " p, %48, %49;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x64 TN F32+=E5M2*E5M2 +// SPARSE GMMA 64x16x64 TN F32+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x160x64_F32E5M2E5M2_SS_TN +struct GMMA_64x16x64_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[80]; + using CRegisters = float[8]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, 
uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %84, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - " %80," - " %81," - " %82, 
%83," - " p, %85, %86;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x160x64 TN F32+=E5M2*E5M2 +// SPARSE GMMA 64x16x64 TN F32+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel 
spsel = GMMA::SparseSel::Zero > -struct GMMA_64x160x64_F32E5M2E5M2_RS_TN +struct GMMA_64x16x64_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[80]; + using CRegisters = float[8]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - float & d00, float & d01, float & d02, float & d03, - float & d04, float & d05, float & d06, float & d07, - float & d08, float & d09, float & d10, float & d11, - float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %87, 0;\n" - 
"wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79}," - "{%80, %81, %82, %83}," - " %84," - " %85, %86," - " p, %88, %89;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), - "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x64 TN F16+=E5M2*E5M2 +// SPARSE GMMA 64x32x64 TN F16+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x176x64_F16E5M2E5M2_SS_TN +struct GMMA_64x32x64_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[44]; + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %48, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43}," - " %44," - " %45," - " %46, %47," - " p, %49, %50;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x64 TN F16+=E5M2*E5M2 +// SPARSE GMMA 64x32x64 TN F16+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x176x64_F16E5M2E5M2_RS_TN +struct GMMA_64x32x64_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[44]; + using CRegisters = uint32_t[8]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, uint64_t const& desc_b, - uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, - uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, - uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, - uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %51, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, 
%41, %42, %43}," - "{%44, %45, %46, %47}," - " %48," - " %49, %50," - " p, %52, %53;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" "}\n" - : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), - "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), - "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) - : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x64 TN F32+=E5M2*E5M2 +// SPARSE GMMA 64x32x64 TN F32+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x176x64_F32E5M2E5M2_SS_TN +struct GMMA_64x32x64_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = 
float[88]; + using CRegisters = float[16]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -51665,98 +21077,52 @@ struct GMMA_64x176x64_F32E5M2E5M2_SS_TN float & d04, float & d05, float & d06, float & d07, float & d08, float & d09, float & d10, float & d11, float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %92, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, 
%69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - " %88," - " %89," - " %90, %91," - " p, %93, %94;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x176x64 TN F32+=E5M2*E5M2 +// SPARSE GMMA 64x32x64 TN F32+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x176x64_F32E5M2E5M2_RS_TN +struct GMMA_64x32x64_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[88]; + using CRegisters = float[16]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -51765,97 +21131,52 @@ struct GMMA_64x176x64_F32E5M2E5M2_RS_TN float & d04, float & d05, float & d06, float & d07, float & d08, float & d09, float & d10, float & d11, float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %95, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87}," - "{%88, %89, %90, %91}," - " %92," - " %93, %94," - " p, %96, %97;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), - "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), - "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), 
"+f"(d86), "+f"(d87) + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x192x64 TN F16+=E5M2*E5M2 +// SPARSE GMMA 64x64x64 TN F16+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x192x64_F16E5M2E5M2_SS_TN +struct GMMA_64x64x64_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -51864,71 +21185,52 @@ struct GMMA_64x192x64_F16E5M2E5M2_SS_TN uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, uint32_t const& e, 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %52, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e5m2.e5m2 " + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e5m2.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - " %48," - " %49," - " %50, %51," - " p, %53, %54;\n" + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x192x64 TN F16+=E5M2*E5M2 +// SPARSE GMMA 64x64x64 TN 
F16+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x192x64_F16E5M2E5M2_RS_TN +struct GMMA_64x64x64_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[48]; + using CRegisters = uint32_t[16]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -51937,71 +21239,52 @@ struct GMMA_64x192x64_F16E5M2E5M2_RS_TN uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, - uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, - uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %55, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47}," - "{%48, %49, %50, %51}," - " %52," - " %53, %54," - " p, %56, %57;\n" + "setp.ne.b32 p, 
%23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), - "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), - "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x192x64 TN F32+=E5M2*E5M2 +// SPARSE GMMA 64x64x64 TN F32+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x192x64_F32E5M2E5M2_SS_TN +struct GMMA_64x64x64_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[96]; + using CRegisters = float[32]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -52010,51 +21293,28 @@ struct GMMA_64x192x64_F32E5M2E5M2_SS_TN float 
& d04, float & d05, float & d06, float & d07, float & d08, float & d09, float & d10, float & d11, float & d12, float & d13, float & d14, float & d15, - float & d16, float & d17, float & d18, float & d19, - float & d20, float & d21, float & d22, float & d23, - float & d24, float & d25, float & d26, float & d27, - float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, - float & d88, float & d89, float & d90, float & d91, - float & d92, float & d93, float & d94, float & d95, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %100, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, 
%42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - " %96," - " %97," - " %98, %99," - " p, %101, %102;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -52063,48 +21323,32 @@ struct GMMA_64x192x64_F32E5M2E5M2_SS_TN "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), - "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), - "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// SPARSE GMMA 64x192x64 TN F32+=E5M2*E5M2 +// SPARSE GMMA 64x64x64 TN F32+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x192x64_F32E5M2E5M2_RS_TN +struct GMMA_64x64x64_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[96]; + using CRegisters = float[32]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -52117,47 +21361,24 @@ struct GMMA_64x192x64_F32E5M2E5M2_RS_TN float & d20, float & d21, float & d22, float & d23, float & d24, float & d25, float & d26, float & d27, float & d28, float & d29, float & d30, float & d31, - float & d32, float & d33, float & d34, float & d35, - float & d36, float & d37, float & d38, float & d39, - float & d40, float & d41, float & d42, float & d43, - float & d44, float & d45, float & d46, float & d47, - float & d48, float & d49, float & d50, float & d51, - float & d52, float & d53, float & d54, float & d55, - float & d56, float & d57, float & d58, float & d59, - float & d60, float & d61, float & d62, float & d63, - float & d64, float & d65, float & d66, float & d67, - float & d68, float & d69, float & d70, float & d71, - float & d72, float & d73, float & d74, float & d75, - float & d76, float & d77, float & d78, float & d79, - float & d80, float & d81, float & d82, float & d83, - float & d84, float & d85, float & d86, float & d87, - float & d88, float & d89, float & d90, float & d91, - float 
& d92, float & d93, float & d94, float & d95, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %103, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95}," - "{%96, %97, %98, %99}," - " %100," - " %101, %102," - " p, %104, %105;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" "}\n" : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), @@ -52166,49 +21387,32 @@ struct GMMA_64x192x64_F32E5M2E5M2_RS_TN "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), - "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), - "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), - "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), - "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), - "+f"(d60), "+f"(d61), 
"+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), - "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), - "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), - "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), - "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), - "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x64 TN F16+=E5M2*E5M2 +// SPARSE GMMA 64x96x64 TN F16+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x208x64_F16E5M2E5M2_SS_TN +struct GMMA_64x96x64_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[52]; + using CRegisters = uint32_t[24]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -52219,74 +21423,55 @@ struct GMMA_64x208x64_F16E5M2E5M2_SS_TN uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & 
d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %56, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51}," - " %52," - " %53," - " %54, %55," - " p, %57, %58;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x64 TN F16+=E5M2*E5M2 +// SPARSE GMMA 64x96x64 TN F16+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x208x64_F16E5M2E5M2_RS_TN +struct GMMA_64x96x64_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[52]; + using CRegisters = uint32_t[24]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -52297,294 +21482,203 @@ struct GMMA_64x208x64_F16E5M2E5M2_RS_TN uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, - uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, - uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, 
desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %59, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51}," - "{%52, %53, %54, %55}," - " %56," - " %57, %58," - " p, %60, %61;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), - "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), - "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x64 TN F32+=E5M2*E5M2 +// SPARSE GMMA 64x96x64 TN F32+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x208x64_F32E5M2E5M2_SS_TN +struct GMMA_64x96x64_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[104]; + using CRegisters = float[48]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float 
& d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %108, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - " %104," - " %105," - " %106, %107," - " p, %109, %110;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, 
%28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), 
"+f"(d45), "+f"(d46), "+f"(d47) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x208x64 TN F32+=E5M2*E5M2 +// SPARSE GMMA 64x96x64 TN F32+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x208x64_F32E5M2E5M2_RS_TN +struct GMMA_64x96x64_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[104]; + using CRegisters = float[48]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & 
d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %111, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e5m2.e5m2 " + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e5m2.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, 
%18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103}," - "{%104, %105, %106, %107}," - " %108," - " %109, %110," - " p, %112, %113;\n" + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) - : 
"r"(a000), "r"(a001), "r"(a002), "r"(a003), + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x64 TN F16+=E5M2*E5M2 +// SPARSE GMMA 64x128x64 TN F16+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x224x64_F16E5M2E5M2_SS_TN +struct GMMA_64x128x64_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -52597,32 +21691,24 @@ struct GMMA_64x224x64_F16E5M2E5M2_SS_TN uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, 
uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %60, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - " %56," - " %57," - " %58, %59," - " p, %61, %62;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -52631,40 +21717,32 @@ struct GMMA_64x224x64_F16E5M2E5M2_SS_TN "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), 
"+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x64 TN F16+=E5M2*E5M2 +// SPARSE GMMA 64x128x64 TN F16+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x224x64_F16E5M2E5M2_RS_TN +struct GMMA_64x128x64_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[56]; + using CRegisters = uint32_t[32]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -52677,32 +21755,24 @@ struct GMMA_64x224x64_F16E5M2E5M2_RS_TN uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, - uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, - uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, - uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, - uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, uint32_t const& e, GMMA::ScaleOut 
const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %63, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55}," - "{%56, %57, %58, %59}," - " %60," - " %61, %62," - " p, %64, %65;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -52711,81 +21781,62 @@ struct GMMA_64x224x64_F16E5M2E5M2_RS_TN "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), - "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), - "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), - "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), - "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x64 TN F32+=E5M2*E5M2 +// SPARSE GMMA 64x128x64 TN F32+=E5M2*E5M2 template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, - GMMA::SparseSel spsel = GMMA::SparseSel::Zero -> -struct GMMA_64x224x64_F32E5M2E5M2_SS_TN -{ - using DRegisters = void; - using ARegisters = uint64_t[1]; - using ERegisters = uint32_t[1]; - using BRegisters = uint64_t[1]; - using CRegisters = float[112]; - - CUTE_HOST_DEVICE static void - fma(uint64_t const& desc_a, - uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & 
d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %116, 0;\n" - 
"wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e5m2.e5m2 " + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e5m2.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -52793,114 +21844,83 @@ struct GMMA_64x224x64_F32E5M2E5M2_SS_TN " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111}," - " %112," - " %113," - " %114, %115," - " p, %117, %118;\n" + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), 
"+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x224x64 TN F32+=E5M2*E5M2 +// SPARSE GMMA 64x128x64 TN F32+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x224x64_F32E5M2E5M2_RS_TN +struct 
GMMA_64x128x64_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[112]; + using CRegisters = float[64]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, 
float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %119, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e5m2.e5m2 " + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e5m2.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -52908,73 +21928,53 @@ struct GMMA_64x224x64_F32E5M2E5M2_RS_TN " %32, %33, %34, %35, %36, %37, %38, %39, " " %40, %41, %42, %43, %44, %45, %46, %47, " " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, 
%111}," - "{%112, %113, %114, %115}," - " %116," - " %117, %118," - " p, %120, %121;\n" + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), 
"+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x64 TN F16+=E5M2*E5M2 +// SPARSE GMMA 64x192x64 TN F16+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x240x64_F16E5M2E5M2_SS_TN +struct GMMA_64x192x64_F16E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[60]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, @@ -52991,30 +21991,26 @@ struct GMMA_64x240x64_F16E5M2E5M2_SS_TN uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, 
uint32_t & d58, uint32_t & d59, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %64, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e5m2.e5m2 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59}," - " %60," - " %61," - " %62, %63," - " p, %65, %66;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -53027,37 +22023,32 @@ struct GMMA_64x240x64_F16E5M2E5M2_SS_TN "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x64 TN F16+=E5M2*E5M2 +// SPARSE GMMA 64x192x64 TN F16+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x240x64_F16E5M2E5M2_RS_TN +struct GMMA_64x192x64_F16E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = uint32_t[60]; + using CRegisters = uint32_t[48]; CUTE_HOST_DEVICE static void fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, @@ -53074,30 +22065,26 @@ struct GMMA_64x240x64_F16E5M2E5M2_RS_TN uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, - uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, - uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, - uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %67, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e5m2.e5m2 " + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e5m2.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " " %24, %25, %26, %27, %28, %29, %30, %31, " " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " 
- " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59}," - "{%60, %61, %62, %63}," - " %64," - " %65, %66," - " p, %68, %69;\n" + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" "}\n" : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), @@ -53110,80 +22097,70 @@ struct GMMA_64x240x64_F16E5M2E5M2_RS_TN "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), - "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), - "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), - "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), - "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x64 TN F32+=E5M2*E5M2 +// SPARSE GMMA 64x192x64 TN F32+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x240x64_F32E5M2E5M2_SS_TN +struct GMMA_64x192x64_F32E5M2E5M2_SS_TN { using DRegisters = void; using ARegisters = uint64_t[1]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[120]; + using CRegisters = float[96]; CUTE_HOST_DEVICE static void fma(uint64_t const& desc_a, uint64_t const& desc_b, - float & d000, float & 
d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & 
d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %124, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e5m2.e5m2 " + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e5m2.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -53195,115 +22172,99 @@ struct GMMA_64x240x64_F32E5M2E5M2_SS_TN " %64, %65, %66, %67, %68, %69, %70, %71, " " %72, %73, %74, %75, %76, %77, %78, %79, " " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119}," - " %120," - " %121," - " %122, %123," - " p, 
%125, %126;\n" + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + 
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) : "l"(desc_a), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -// SPARSE GMMA 64x240x64 TN F32+=E5M2*E5M2 +// SPARSE GMMA 64x192x64 TN F32+=E5M2*E5M2 template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, GMMA::SparseSel spsel = GMMA::SparseSel::Zero > -struct GMMA_64x240x64_F32E5M2E5M2_RS_TN +struct GMMA_64x192x64_F32E5M2E5M2_RS_TN { using DRegisters = void; using ARegisters = uint32_t[4]; using ERegisters = uint32_t[1]; using BRegisters = uint64_t[1]; - using CRegisters = float[120]; + using CRegisters = float[96]; CUTE_HOST_DEVICE static void - fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + fma(uint32_t const& a00, uint32_t const& a01, uint32_t 
const& a02, uint32_t const& a03, uint64_t const& desc_b, - float & d000, float & d001, float & d002, float & d003, - float & d004, float & d005, float & d006, float & d007, - float & d008, float & d009, float & d010, float & d011, - float & d012, float & d013, float & d014, float & d015, - float & d016, float & d017, float & d018, float & d019, - float & d020, float & d021, float & d022, float & d023, - float & d024, float & d025, float & d026, float & d027, - float & d028, float & d029, float & d030, float & d031, - float & d032, float & d033, float & d034, float & d035, - float & d036, float & d037, float & d038, float & d039, - float & d040, float & d041, float & d042, float & d043, - float & d044, float & d045, float & d046, float & d047, - float & d048, float & d049, float & d050, float & d051, - float & d052, float & d053, float & d054, float & d055, - float & d056, float & d057, float & d058, float & d059, - float & d060, float & d061, float & d062, float & d063, - float & d064, float & d065, float & d066, float & d067, - float & d068, float & d069, float & d070, float & d071, - float & d072, float & d073, float & d074, float & d075, - float & d076, float & d077, float & d078, float & d079, - float & d080, float & d081, float & d082, float & d083, - float & d084, float & d085, float & d086, float & d087, - float & d088, float & d089, float & d090, float & d091, - float & d092, float & d093, float & d094, float & d095, - float & d096, float & d097, float & d098, float & d099, - float & d100, float & d101, float & d102, float & d103, - float & d104, float & d105, float & d106, float & d107, - float & d108, float & d109, float & d110, float & d111, - float & d112, float & d113, float & d114, float & d115, - float & d116, float & d117, float & d118, float & d119, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, 
float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, uint32_t const& e, GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" - "setp.ne.b32 p, %127, 0;\n" - "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e5m2.e5m2 " + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e5m2.e5m2 " "{%0, %1, %2, %3, %4, %5, %6, %7, " " %8, %9, %10, %11, %12, %13, %14, %15, " " %16, %17, %18, %19, %20, %21, %22, %23, " @@ -53315,55 +22276,45 @@ struct GMMA_64x240x64_F32E5M2E5M2_RS_TN " %64, %65, %66, %67, %68, %69, %70, %71, " " %72, %73, %74, %75, %76, %77, %78, %79, " " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, 
%117, %118, %119}," - "{%120, %121, %122, %123}," - " %124," - " %125, %126," - " p, %128, %129;\n" + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105;\n" "}\n" - : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), - "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), - "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), - "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), - "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), - "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), - "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), - "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), - "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), - "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), - "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), - "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), - "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), - "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), - "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), - "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), - "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), - "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), - "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), - "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), - "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), - "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), - "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), - "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), - "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), - "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), - "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), - "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), - "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), - "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) - : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), 
"+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), "l"(desc_b), "r"(e), "n"(int32_t(spsel)), "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); #endif } }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -53404,6 +22355,7 @@ struct GMMA_64x256x64_F16E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -53487,6 +22439,7 @@ struct GMMA_64x256x64_F16E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -53586,6 +22539,7 @@ 
struct GMMA_64x256x64_F32E5M2E5M2_SS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -53709,6 +22663,7 @@ struct GMMA_64x256x64_F32E5M2E5M2_RS_TN GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) { #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); asm volatile( "{\n" ".reg .pred p;\n" @@ -53779,11 +22734,10 @@ struct GMMA_64x256x64_F32E5M2E5M2_RS_TN //////////////////////////////////////////////////////////////////////////////////////////////////// -//////////////////////////////////////////////////////////////////////////////////////////////////// - } // namespace SM90::GMMA::SPARSE - -//////////////////////////////////////////////////////////////////////////////////////////////////// - } // namespace cute + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +#include "mma_sm90_gmma_sparse_ext.hpp" +#endif diff --git a/include/cute/arch/mma_sm90_gmma_sparse_ext.hpp b/include/cute/arch/mma_sm90_gmma_sparse_ext.hpp new file mode 100644 index 000000000..c224e4034 --- /dev/null +++ b/include/cute/arch/mma_sm90_gmma_sparse_ext.hpp @@ -0,0 +1,60445 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include // CUTE_HOST_DEVICE + +#include "cutlass/arch/synclog.hpp" + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) +# define CUTE_ARCH_MMA_SM90A_ENABLED +#endif + +namespace cute { + +namespace SM90::GMMA::SPARSE { + +// SPARSE GMMA 64x24x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " %8, %9," + " p, %11, %12, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x32 F16+=F16*F16 +template < + 
GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " %11, %12," + " p, %14, %15, %16;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using 
ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " %12, %13," + " p, %15, %16, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, 
uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " %15, %16," + " p, %18, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D 
= GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n56k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " %16, %17," + " p, %19, %20, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n56k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " %19, %20," + " p, %22, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, 
%7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " %20, %21," + " p, %23, %24, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + 
"setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " %23, %24," + " p, %26, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg 
.pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, 
uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " %24, %25," + " p, %27, %28, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, 
uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " %27, %28," + " p, %30, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static 
void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " %28, %29," + " p, %31, %32, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero +> +struct GMMA_64x104x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " %31, %32," + " p, %34, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x104x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), 
"+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f16.f16.f16 " + "{%0, %1, %2, 
%3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, 
uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " %32, %33," + " p, %35, %36, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& 
a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " %35, %36," + " p, %38, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " %36, %37," + " p, %39, %40, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + 
"l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k32.f16.f16.f16 " + "{%0, %1, 
%2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " %39, %40," + " p, %42, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & 
d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = 
uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), 
"r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " 
%16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " %40, %41," + " p, %43, %44, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, 
uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " %43, %44," + " p, %46, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn 
scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), 
"+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, 
+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " %44, %45," + " p, %47, %48, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, 
%9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " %47, %48," + " p, %50, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + 
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, 
%30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, 
uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + " %46," + " %47," + " %48, %49," + " p, %51, %52, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + 
" %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " %51, %52," + " p, %54, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, 
uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " %52, %53," + " p, %55, %56, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + 
"setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " %55, %56," + " p, %58, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, 
+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), 
"+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, 
uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn 
scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " %56, %57," + " p, %59, %60, %61, 
%62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, 
uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " %59, %60," + " p, %62, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : 
"r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F16F16F16_RS +{ + using 
DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " 
%61, %62," + " p, %64, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t 
& d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " %60, %61," + " p, %63, %64, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + 
"+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & 
d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " %63, %64," + " p, %66, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x32 F16+=F16*F16 +template < 
+ GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, 
%38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + 
uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), 
"+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, 
uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " %64, %65," + " p, %67, %68, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, 
uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " %67, %68," + " p, %70, %71, %72;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x32 
F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x32_F32F16F16_RS +{ + using DRegisters = void; + 
using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using 
CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + 
static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F32F16F16_SS +{ + using DRegisters = void; + 
using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + 
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n72k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, 
float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + 
GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), 
"n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, 
%13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & 
d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), 
"+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    // Sparse warpgroup MMA (wgmma ... .sp): %52/%53 are the A/B shared-memory
    // matrix descriptors, %54 the sparsity-metadata word `e`, %55 the
    // sparse-selector immediate. Predicate p is derived from scale_D (%56) and
    // substitutes for it in the instruction; %57..%60 are the compile-time
    // scaleA/scaleB/tnspA/tnspB immediates.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %56, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n104k32.f32.f16.f16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51},"
      " %52,"
      " %53,"
      " %54, %55,"
      " p, %57, %58, %59, %60;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51)
      :  "l"(desc_a),
         "l"(desc_b),
         "r"(e), "n"(int32_t(spsel)),
         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
#else
    // Compile-time trap: this wrapper is only usable when targeting sm_90a.
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x104x32 F32+=F16*F16
// Register-sourced (RS) variant: operand A comes from four 32-bit registers
// (a00..a03) instead of a shared-memory descriptor; B still comes from a
// shared-memory descriptor. D (64x104, f32) lives in 52 accumulator registers
// updated in place. `e` is the sparsity-metadata word for the .sp instruction;
// spsel/scaleA/scaleB/tnspB are baked in as immediates.
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x104x32_F32F16F16_RS
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[52];

  // Register-resident A fragments are only defined for K-major layouts.
  static_assert(tnspA == GMMA::Major::K,
      "Register source operand A must have K major layout.");

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      float & d28, float & d29, float & d30, float & d31,
      float & d32, float & d33, float & d34, float & d35,
      float & d36, float & d37, float & d38, float & d39,
      float & d40, float & d41, float & d42, float & d43,
      float & d44, float & d45, float & d46, float & d47,
      float & d48, float & d49, float & d50, float & d51,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    // Operand map: %0-%51 accumulators, {%52-%55} A registers, %56 B
    // descriptor, %57 metadata, %58 spsel; p replaces scale_D (%59);
    // %60-%62 = scaleA/scaleB/tnspB (no tnspA operand in the RS form).
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %59, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n104k32.f32.f16.f16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51},"
      "{%52, %53, %54, %55},"
      " %56,"
      " %57, %58,"
      " p, %60, %61, %62;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51)
      :  "r"(a00), "r"(a01), "r"(a02), "r"(a03),
         "l"(desc_b),
         "r"(e), "n"(int32_t(spsel)),
         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x112x32 F32+=F16*F16
// Shared-memory-sourced (SS) variant: A and B are both given by smem matrix
// descriptors; 56 f32 accumulators updated in place.
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x112x32_F32F16F16_SS
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[56];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      float & d28, float & d29, float & d30, float & d31,
      float & d32, float & d33, float & d34, float & d35,
      float & d36, float & d37, float & d38, float & d39,
      float & d40, float & d41, float & d42, float & d43,
      float & d44, float & d45, float & d46, float & d47,
      float & d48, float & d49, float & d50, float & d51,
      float & d52, float & d53, float & d54, float & d55,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %60, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n112k32.f32.f16.f16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55},"
      " %56,"
      " %57,"
      " %58, %59,"
      " p, %61, %62, %63, %64;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
        "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55)
      :  "l"(desc_a),
         "l"(desc_b),
         "r"(e), "n"(int32_t(spsel)),
         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x112x32 F32+=F16*F16
// Register-sourced (RS) variant: A in 4 regs, B via smem descriptor;
// 56 f32 accumulators.
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x112x32_F32F16F16_RS
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[56];

  static_assert(tnspA == GMMA::Major::K,
      "Register source operand A must have K major layout.");

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      float & d28, float & d29, float & d30, float & d31,
      float & d32, float & d33, float & d34, float & d35,
      float & d36, float & d37, float & d38, float & d39,
      float & d40, float & d41, float & d42, float & d43,
      float & d44, float & d45, float & d46, float & d47,
      float & d48, float & d49, float & d50, float & d51,
      float & d52, float & d53, float & d54, float & d55,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %63, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n112k32.f32.f16.f16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55},"
      "{%56, %57, %58, %59},"
      " %60,"
      " %61, %62,"
      " p, %64, %65, %66;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
        "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55)
      :  "r"(a00), "r"(a01), "r"(a02), "r"(a03),
         "l"(desc_b),
         "r"(e), "n"(int32_t(spsel)),
         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x120x32 F32+=F16*F16
// SS variant: A and B via smem descriptors; 60 f32 accumulators.
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x120x32_F32F16F16_SS
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[60];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      float & d28, float & d29, float & d30, float & d31,
      float & d32, float & d33, float & d34, float & d35,
      float & d36, float & d37, float & d38, float & d39,
      float & d40, float & d41, float & d42, float & d43,
      float & d44, float & d45, float & d46, float & d47,
      float & d48, float & d49, float & d50, float & d51,
      float & d52, float & d53, float & d54, float & d55,
      float & d56, float & d57, float & d58, float & d59,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %64, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n120k32.f32.f16.f16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59},"
      " %60,"
      " %61,"
      " %62, %63,"
      " p, %65, %66, %67, %68;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
        "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
        "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59)
      :  "l"(desc_a),
         "l"(desc_b),
         "r"(e), "n"(int32_t(spsel)),
         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x120x32 F32+=F16*F16
// RS variant: A in 4 regs, B via smem descriptor; 60 f32 accumulators.
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x120x32_F32F16F16_RS
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[60];

  static_assert(tnspA == GMMA::Major::K,
      "Register source operand A must have K major layout.");

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      float & d28, float & d29, float & d30, float & d31,
      float & d32, float & d33, float & d34, float & d35,
      float & d36, float & d37, float & d38, float & d39,
      float & d40, float & d41, float & d42, float & d43,
      float & d44, float & d45, float & d46, float & d47,
      float & d48, float & d49, float & d50, float & d51,
      float & d52, float & d53, float & d54, float & d55,
      float & d56, float & d57, float & d58, float & d59,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %67, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n120k32.f32.f16.f16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59},"
      "{%60, %61, %62, %63},"
      " %64,"
      " %65, %66,"
      " p, %68, %69, %70;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
        "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
        "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59)
      :  "r"(a00), "r"(a01), "r"(a02), "r"(a03),
         "l"(desc_b),
         "r"(e), "n"(int32_t(spsel)),
         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x136x32 F32+=F16*F16
// SS variant: A and B via smem descriptors; 68 f32 accumulators.
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x136x32_F32F16F16_SS
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[68];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      float & d28, float & d29, float & d30, float & d31,
      float & d32, float & d33, float & d34, float & d35,
      float & d36, float & d37, float & d38, float & d39,
      float & d40, float & d41, float & d42, float & d43,
      float & d44, float & d45, float & d46, float & d47,
      float & d48, float & d49, float & d50, float & d51,
      float & d52, float & d53, float & d54, float & d55,
      float & d56, float & d57, float & d58, float & d59,
      float & d60, float & d61, float & d62, float & d63,
      float & d64, float & d65, float & d66, float & d67,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %72, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n136k32.f32.f16.f16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67},"
      " %68,"
      " %69,"
      " %70, %71,"
      " p, %73, %74, %75, %76;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
        "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
        "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
        "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
        "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67)
      :  "l"(desc_a),
         "l"(desc_b),
         "r"(e), "n"(int32_t(spsel)),
         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x136x32 F32+=F16*F16
// RS variant: A in 4 regs, B via smem descriptor; 68 f32 accumulators.
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x136x32_F32F16F16_RS
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[68];

  static_assert(tnspA == GMMA::Major::K,
      "Register source operand A must have K major layout.");

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      float & d28, float & d29, float & d30, float & d31,
      float & d32, float & d33, float & d34, float & d35,
      float & d36, float & d37, float & d38, float & d39,
      float & d40, float & d41, float & d42, float & d43,
      float & d44, float & d45, float & d46, float & d47,
      float & d48, float & d49, float & d50, float & d51,
      float & d52, float & d53, float & d54, float & d55,
      float & d56, float & d57, float & d58, float & d59,
      float & d60, float & d61, float & d62, float & d63,
      float & d64, float & d65, float & d66, float & d67,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %75, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n136k32.f32.f16.f16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67},"
      "{%68, %69, %70, %71},"
      " %72,"
      " %73, %74,"
      " p, %76, %77, %78;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
        "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
        "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
        "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
        "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67)
      :  "r"(a00), "r"(a01), "r"(a02), "r"(a03),
         "l"(desc_b),
         "r"(e), "n"(int32_t(spsel)),
         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x144x32 F32+=F16*F16
// SS variant: A and B via smem descriptors; 72 f32 accumulators.
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x144x32_F32F16F16_SS
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[72];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      float & d28, float & d29, float & d30, float & d31,
      float & d32, float & d33, float & d34, float & d35,
      float & d36, float & d37, float & d38, float & d39,
      float & d40, float & d41, float & d42, float & d43,
      float & d44, float & d45, float & d46, float & d47,
      float & d48, float & d49, float & d50, float & d51,
      float & d52, float & d53, float & d54, float & d55,
      float & d56, float & d57, float & d58, float & d59,
      float & d60, float & d61, float & d62, float & d63,
      float & d64, float & d65, float & d66, float & d67,
      float & d68, float & d69, float & d70, float & d71,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %76, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n144k32.f32.f16.f16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71},"
      " %72,"
      " %73,"
      " %74, %75,"
      " p, %77, %78, %79, %80;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
        "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
        "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
        "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
        "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
        "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71)
      :  "l"(desc_a),
         "l"(desc_b),
         "r"(e), "n"(int32_t(spsel)),
         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x144x32 F32+=F16*F16
// RS variant: A in 4 regs, B via smem descriptor; 72 f32 accumulators.
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x144x32_F32F16F16_RS
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[72];

  static_assert(tnspA == GMMA::Major::K,
      "Register source operand A must have K major layout.");

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      float & d28, float & d29, float & d30, float & d31,
      float & d32, float & d33, float & d34, float & d35,
      float & d36, float & d37, float & d38, float & d39,
      float & d40, float & d41, float & d42, float & d43,
      float & d44, float & d45, float & d46, float & d47,
      float & d48, float & d49, float & d50, float & d51,
      float & d52, float & d53, float & d54, float & d55,
      float & d56, float & d57, float & d58, float & d59,
      float & d60, float & d61, float & d62, float & d63,
      float & d64, float & d65, float & d66, float & d67,
      float & d68, float & d69, float & d70, float & d71,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %79, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n144k32.f32.f16.f16 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71},"
      "{%72, %73, %74, %75},"
      " %76,"
      " %77, %78,"
      " p, %80, %81, %82;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
        "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
        "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
        "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
        "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
        "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71)
      :  "r"(a00), "r"(a01), "r"(a02), "r"(a03),
         "l"(desc_b),
         "r"(e), "n"(int32_t(spsel)),
         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x152x32 F32+=F16*F16
// SS variant: A and B via smem descriptors; 76 f32 accumulators.
template <
  GMMA::Major tnspA,
  GMMA::Major tnspB,
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x152x32_F32F16F16_SS
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[76];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      float & d28, float & d29, float & d30, float & d31,
      float & d32, float & d33, float & d34, float & d35,
      float & d36, float & d37, float & d38, float & d39,
      float & d40, float & d41, float & d42, float & d43,
      float & d44, float & d45, float & d46, float & d47,
      float & d48, float & d49, float & d50, float & d51,
      float & d52, float & d53, float & d54, float & d55,
      float & d56, float & d57, float & d58, float & d59,
      float & d60, float & d61, float & d62, float & d63,
      float & d64, float & d65, float & d66, float & d67,
      float & d68, float & d69, float & d70, float & d71,
      float & d72, float & d73, float & d74, float & d75,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %80, 0;\n"
"wgmma.mma_async.sp.sync.aligned.m64n152k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " %78, %79," + " p, %81, %82, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %83, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " 
%16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " %81, %82," + " p, %84, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x160x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " 
+ " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = 
float[80]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, 
%61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + 
CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %88, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " 
%80, %81, %82, %83}," + " %84," + " %85," + " %86, %87," + " p, %89, %90, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + static_assert(tnspA == GMMA::Major::K, + "Register source 
operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %91, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + 
" %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " %89, %90," + " p, %92, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = 
uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, 
%63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = 
uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + 
" %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, 
+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %96, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k32.f32.f16.f16 " 
+ "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " %94, %95," + " p, %97, %98, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & 
d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %99, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " %97, %98," + " p, %100, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), 
"+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & 
d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %104, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " %102, %103," + " p, %105, %106, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), 
"+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & 
d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %107, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, 
%76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " %105, %106," + " p, %108, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x32 F32+=F16*F16 +template < + GMMA::Major 
tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), 
"+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float 
& d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), 
"+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x32_F32F16F16_SS +{ + using DRegisters = void; + using 
ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" 
+ ".reg .pred p;\n" + "setp.ne.b32 p, %112, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " %110, %111," + " p, %113, %114, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), 
"+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + 
float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %115, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " %113, %114," + " p, %116, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), 
"+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F32F16F16_SS +{ + using 
DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), 
"+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, 
float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + 
" %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 
SPARSE GMMA 64x232x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, 
float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %120, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " %118, %119," + " p, %121, %122, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + 
"+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & 
d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %123, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, 
%23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " %121, %122," + " p, %124, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), 
"+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, 
float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + 
"+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn 
scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & 
d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + 
"+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float 
& d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %128, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n248k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " %126, %127," + " p, %129, %130, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), 
"+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, 
float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %131, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, 
%59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " %129, %130," + " p, %132, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + 
"+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x32_F32BF16BF16_RS without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x32_F32BF16BF16_SS without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + 
"r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + 
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, 
%18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut 
const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & 
d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x72x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + 
"l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, 
%38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + 
float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x88x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), 
"+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, 
%51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & 
d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n104k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& 
desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), 
"+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, 
float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), 
"+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float 
& d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), 
"r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %72, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " %70, %71," + " p, %73, %74, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn 
scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %75, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, 
%33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " %73, %74," + " p, %76, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t 
const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), 
"+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + 
float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), 
"+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, 
float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %80, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " %78, %79," + " p, %81, %82, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), 
"+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + 
float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %83, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " %81, %82," + " p, %84, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), 
"r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, 
float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, 
float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { 
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %88, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " %86, %87," + " p, %89, %90, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x168x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float 
& d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %91, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " %89, %90," + " p, %92, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float 
& d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), 
"n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & 
d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), 
"+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & 
d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %96, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " %94, %95," + " p, %97, %98, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), 
"+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & 
d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %99, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " %97, %98," + " p, %100, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), 
"+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & 
d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %104, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, 
%84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " %102, %103," + " p, %105, %106, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, 
float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %107, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " %105, %106," + " p, %108, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), 
"+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & 
d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), 
"+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + static_assert(tnspA == GMMA::Major::K, + 
"Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n208k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + 
"+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, 
float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %112, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " %110, %111," + " p, %113, %114, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), 
"+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + 
fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %115, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n216k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " %113, %114," + " p, %116, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), 
"+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, 
float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + 
"+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = 
uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut 
const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), 
"+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + 
float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %120, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, 
%100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " %118, %119," + " p, %121, %122, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & 
d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %123, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " %121, %122," + " p, %124, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), 
"+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t 
const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + 
"setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), 
"+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, 
float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, 
%77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & 
d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %128, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " %126, %127," + " p, %129, %130, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), 
"+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + 
GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, 
float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %131, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " %129, %130," + " p, %132, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), 
"+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & 
// SPARSE GMMA 64x24x16 TN F32+=TF32*TF32
// Register-source variant (RS): A in four 32-bit registers, B via a
// shared-memory descriptor.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x24x16_F32TF32TF32_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];  // A fragment held in registers
  using ERegisters = uint32_t[1];  // sparsity metadata
  using BRegisters = uint64_t[1];  // shared-memory descriptor for B
  using CRegisters = float[12];    // accumulator fragment (d00..d11)

  // Operand map: %0-%11 accumulators, %12-%15 a00..a03, %16 desc_b,
  // %17 e, %18 spsel, %19 scale_D (predicate p), %20/%21 scaleA/scaleB.
  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %19, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n24k16.f32.tf32.tf32 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11},"
      "{%12, %13, %14, %15},"
      " %16,"
      " %17, %18,"
      " p, %20, %21;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11)
      :  "r"(a00), "r"(a01), "r"(a02), "r"(a03),
         "l"(desc_b),
         "r"(e), "n"(int32_t(spsel)),
         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};
// SPARSE GMMA 64x40x16 TN F32+=TF32*TF32
// Both operands via shared-memory descriptors (SS); 20 f32 accumulators.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x40x16_F32TF32TF32_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];  // shared-memory descriptor for A
  using ERegisters = uint32_t[1];  // sparsity metadata
  using BRegisters = uint64_t[1];  // shared-memory descriptor for B
  using CRegisters = float[20];    // accumulator fragment (d00..d19)

  // Operand map: %0-%19 accumulators, %20 desc_a, %21 desc_b, %22 e,
  // %23 spsel, %24 scale_D (predicate p), %25/%26 scaleA/scaleB.
  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %24, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n40k16.f32.tf32.tf32 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19},"
      " %20,"
      " %21,"
      " %22, %23,"
      " p, %25, %26;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19)
      :  "l"(desc_a),
         "l"(desc_b),
         "r"(e), "n"(int32_t(spsel)),
         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};
// SPARSE GMMA 64x40x16 TN F32+=TF32*TF32
// Register-source variant (RS): A in four 32-bit registers, B via a
// shared-memory descriptor.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x40x16_F32TF32TF32_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];  // A fragment held in registers
  using ERegisters = uint32_t[1];  // sparsity metadata
  using BRegisters = uint64_t[1];  // shared-memory descriptor for B
  using CRegisters = float[20];    // accumulator fragment (d00..d19)

  // Operand map: %0-%19 accumulators, %20-%23 a00..a03, %24 desc_b,
  // %25 e, %26 spsel, %27 scale_D (predicate p), %28/%29 scaleA/scaleB.
  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %27, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n40k16.f32.tf32.tf32 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19},"
      "{%20, %21, %22, %23},"
      " %24,"
      " %25, %26,"
      " p, %28, %29;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19)
      :  "r"(a00), "r"(a01), "r"(a02), "r"(a03),
         "l"(desc_b),
         "r"(e), "n"(int32_t(spsel)),
         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};
// SPARSE GMMA 64x48x16 TN F32+=TF32*TF32
// Both operands via shared-memory descriptors (SS); 24 f32 accumulators.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x48x16_F32TF32TF32_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];  // shared-memory descriptor for A
  using ERegisters = uint32_t[1];  // sparsity metadata
  using BRegisters = uint64_t[1];  // shared-memory descriptor for B
  using CRegisters = float[24];    // accumulator fragment (d00..d23)

  // Operand map: %0-%23 accumulators, %24 desc_a, %25 desc_b, %26 e,
  // %27 spsel, %28 scale_D (predicate p), %29/%30 scaleA/scaleB.
  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %28, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n48k16.f32.tf32.tf32 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23},"
      " %24,"
      " %25,"
      " %26, %27,"
      " p, %29, %30;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23)
      :  "l"(desc_a),
         "l"(desc_b),
         "r"(e), "n"(int32_t(spsel)),
         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};
" %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, 
// SPARSE GMMA 64x56x16 TN F32+=TF32*TF32
// Both operands via shared-memory descriptors (SS); 28 f32 accumulators.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x56x16_F32TF32TF32_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];  // shared-memory descriptor for A
  using ERegisters = uint32_t[1];  // sparsity metadata
  using BRegisters = uint64_t[1];  // shared-memory descriptor for B
  using CRegisters = float[28];    // accumulator fragment (d00..d27)

  // Operand map: %0-%27 accumulators, %28 desc_a, %29 desc_b, %30 e,
  // %31 spsel, %32 scale_D (predicate p), %33/%34 scaleA/scaleB.
  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %32, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n56k16.f32.tf32.tf32 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27},"
      " %28,"
      " %29,"
      " %30, %31,"
      " p, %33, %34;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27)
      :  "l"(desc_a),
         "l"(desc_b),
         "r"(e), "n"(int32_t(spsel)),
         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};
"wgmma.mma_async.sp.sync.aligned.m64n56k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float 
& d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + 
using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; 
+ +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), 
"+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n80k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & 
d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), 
"+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( 
+ "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, 
+ uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), 
"+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n112k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t 
const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), 
"+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x16_F32TF32TF32_RS_TN +{ + using DRegisters = 
void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), 
"+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & 
d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %72, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " %70, %71," + " p, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm 
volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %75, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " %73, %74," + " p, %76, %77;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> 
+struct GMMA_64x144x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," 
+ " %72," + " %73," + " %74, %75," + " p, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float 
& d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + 
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, 
float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %80, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " %78, %79," + " p, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), 
"+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, 
float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %83, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " %81, %82," + " p, %84, %85;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x152x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + 
"{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x16 TN F32+=TF32*TF32 
+template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, 
%17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x168x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %88, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, 
%45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " %86, %87," + " p, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + 
using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %91, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, 
%59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " %89, %90," + " p, %92, %93;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using 
CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, 
%65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = 
float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, 
%59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using 
ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %96, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, 
%44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " %94, %95," + " p, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> 
+struct GMMA_64x184x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %99, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + 
" %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " %97, %98," + " p, %100, %101;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & 
d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %104, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " %102, %103," + " p, %105, %106;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), 
"+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, 
float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %107, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " %105, %106," + " p, %108, %109;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), 
"+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, 
float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, 
%57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + 
float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), 
"+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & 
d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %112, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " %110, %111," + " p, %113, %114;\n" + "}\n" + 
: "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x16_F32TF32TF32_RS_TN +{ + using 
DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %115, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " %113, %114," + " p, %116, %117;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), 
"+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + 
float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), 
"+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + 
using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), 
"+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, 
float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %120, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " %118, %119," + " p, %121, 
%122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + 
float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %123, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " %121, %122," + " p, %124, %125;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), 
"+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float 
& d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, 
%61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + 
float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + 
"+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters 
= uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %128, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " %126, %127," + " p, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), 
"+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, 
float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %131, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, 
%47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " %129, %130," + " p, %132, %133;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), 
"+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + 
fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & 
d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + 
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// 
+
+// SPARSE GMMA 64x80x64 TN S32+=S8*S8
+// Saturating (".satfinite") variant of the sparse warpgroup MMA: the s32
+// accumulator results are clamped rather than wrapped on overflow. "SS" means
+// both A and B operands are addressed through 64-bit shared-memory matrix
+// descriptors (see the smem_smem synclog call below). The 64x80 s32 output
+// tile is held in 40 uint32_t accumulator registers; "e" is the sparsity
+// metadata word for the ".sp" form, and spsel (a compile-time template value,
+// emitted as an "n" immediate) selects the metadata partition.
+template <
+  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x80x64_S32S8S8_SS_TN_SATURATE
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using ERegisters = uint32_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[40];
+
+  // Issues one wgmma.mma_async.sp. Operand map: %0-%39 accumulators (read-
+  // write), %40 desc_a, %41 desc_b, %42 metadata e, %43 spsel immediate,
+  // %44 scale_D. The asm materializes scale_D as predicate p (setp on %44),
+  // so ScaleOut::Zero makes the instruction overwrite rather than accumulate.
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
+      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
+      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
+      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
+      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
+      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
+      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
+      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
+      uint32_t const& e,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %44, 0;\n"
+      "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.s8.satfinite "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27, %28, %29, %30, %31, "
+      " %32, %33, %34, %35, %36, %37, %38, %39},"
+      " %40,"
+      " %41,"
+      " %42, %43,"
+      " p;\n"
+    "}\n"
+      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
+        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
+        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
+        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
+        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
+        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
+        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
+        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39)
+      :  "l"(desc_a),
+         "l"(desc_b),
+         "r"(e), "n"(int32_t(spsel)),
+         "r"(int32_t(scale_D)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x112x64 TN S32+=S8*S8
+// Non-saturating sibling at N=112: same descriptor-based (SS) operand scheme,
+// with 56 uint32_t accumulators for the 64x112 s32 tile.
+template <
+  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x112x64_S32S8S8_SS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using ERegisters = uint32_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[56];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
+      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
+      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
+      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
+      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
+      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
+      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
+      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
+      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
+      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
+      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
+      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
+      uint32_t const& e,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__,
desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, 
uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), 
"+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, 
uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + 
} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using 
CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, 
%43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, 
uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, 
%70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t 
& d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, 
%76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, 
uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " 
+ " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t 
const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), 
"+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & 
d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), 
"+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & 
d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + 
"setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + 
"+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & 
d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), 
"+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + 
using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, 
uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), 
"+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + 
uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), 
"+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + 
"r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN 
S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 
SPARSE GMMA 64x48x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x80x64 TN S32+=S8*S8
//
// Operand contract shared by the GMMA_64xNx64_*_RS_TN[_SATURATE] atoms in this file:
//   - D: N/2 u32 accumulator registers, updated in place ("+r" constraints).
//   - A: 4 u32 register fragments a00..a03 ("r" inputs) -- the "RS" form reads A from registers.
//   - B: referenced through the 64-bit matrix descriptor `desc_b` ("l" input).
//   - e: sparsity-metadata word consumed by the ".sp" instruction; `spsel` is an immediate ("n").
//   - scale_D: lowered to predicate `p` (setp.ne on the last asm operand) and passed as the
//     instruction's final operand (the scale-d input of wgmma.mma_async.sp).
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x80x64_S32S8S8_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[40];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %47, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.s8 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39},"
      "{%40, %41, %42, %43},"
      " %44,"
      " %45, %46,"
      " p;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39)
      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x80x64 TN S32+=S8*S8, saturating variant (".satfinite" clamps the s32 accumulator).
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x80x64_S32S8S8_RS_TN_SATURATE
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[40];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %47, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.s8.satfinite "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39},"
      "{%40, %41, %42, %43},"
      " %44,"
      " %45, %46,"
      " p;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39)
      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x112x64 TN S32+=S8*S8
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x112x64_S32S8S8_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[56];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %63, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.s8 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55},"
      "{%56, %57, %58, %59},"
      " %60,"
      " %61, %62,"
      " p;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55)
      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x112x64 TN S32+=S8*S8, saturating variant (".satfinite" clamps the s32 accumulator).
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x112x64_S32S8S8_RS_TN_SATURATE
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[56];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %63, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.s8.satfinite "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55},"
      "{%56, %57, %58, %59},"
      " %60,"
      " %61, %62,"
      " p;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55)
      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x144x64 TN S32+=S8*S8
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x144x64_S32S8S8_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[72];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
      uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
      uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
      uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
      uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %79, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.s8 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71},"
      "{%72, %73, %74, %75},"
      " %76,"
      " %77, %78,"
      " p;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
        "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
        "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
        "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
        "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71)
      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x144x64 TN S32+=S8*S8, saturating variant (".satfinite" clamps the s32 accumulator).
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x144x64_S32S8S8_RS_TN_SATURATE
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[72];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
      uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
      uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
      uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
      uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %79, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.s8.satfinite "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71},"
      "{%72, %73, %74, %75},"
      " %76,"
      " %77, %78,"
      " p;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
        "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
        "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
        "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
        "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71)
      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x160x64 TN S32+=S8*S8
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x160x64_S32S8S8_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[80];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
      uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
      uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
      uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
      uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
      uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
      uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %87, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.s8 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79},"
      "{%80, %81, %82, %83},"
      " %84,"
      " %85, %86,"
      " p;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
        "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
        "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
        "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
        "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
        "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
        "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79)
      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x160x64 TN S32+=S8*S8, saturating variant (".satfinite" clamps the s32 accumulator).
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x160x64_S32S8S8_RS_TN_SATURATE
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[80];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
      uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
      uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
      uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
      uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
      uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
      uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %87, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.s8.satfinite "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79},"
      "{%80, %81, %82, %83},"
      " %84,"
      " %85, %86,"
      " p;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
        "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
        "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
        "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
        "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
        "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
        "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79)
      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x176x64 TN S32+=S8*S8
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x176x64_S32S8S8_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[88];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
      uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
      uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
      uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
      uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
      uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
      uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
      uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83,
      uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %95, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.s8 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87},"
      "{%88, %89, %90, %91},"
      " %92,"
      " %93, %94,"
      " p;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
        "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
        "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
        "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
        "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
        "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
        "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79),
        "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83),
        "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87)
      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x176x64 TN S32+=S8*S8, saturating variant (".satfinite" clamps the s32 accumulator).
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x176x64_S32S8S8_RS_TN_SATURATE
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[88];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
      uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
      uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
      uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
      uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
      uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
      uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
      uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83,
      uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %95, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.s8.satfinite "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87},"
      "{%88, %89, %90, %91},"
      " %92,"
      " %93, %94,"
      " p;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
        "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
        "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
        "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
        "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
        "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
        "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79),
        "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83),
        "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87)
      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x208x64 TN S32+=S8*S8
// NOTE: this definition continues past the end of this chunk (104 accumulators; the
// 3-digit operand naming a000../d000.. is used for shapes with >=100 registers).
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x208x64_S32S8S8_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[104];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
      uint64_t const& desc_b,
      uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
      uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
      uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
      uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
      uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
      uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
      uint32_t & d024, uint32_t &
d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " 
+ " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, 
uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + 
"+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t 
& d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, 
%31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), 
"r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t 
& d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), 
"+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, 
uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & 
d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + 
"+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & 
d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 
p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), 
"+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + 
"r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8U8_SS_TN +{ + using DRegisters = 
void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = 
uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using 
CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), 
"+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, 
uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = 
uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + 
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, 
uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + 
"l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=S8*U8 +template < + GMMA::SparseSel 
spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, 
%6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + 
using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, 
%34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + 
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " 
%40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = 
uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, 
%20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using 
ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & 
d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), 
"+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + 
uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + 
"+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, 
uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) 
+ { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), 
"+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, 
uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, 
%102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x240x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & 
d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), 
"+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t 
& d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), 
"+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), 
"+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); 
+#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), 
"+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & 
d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), 
"+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + 
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, 
%68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, 
uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), 
////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x160x64 TN S32+=S8*U8
//
// 2:4 structured-sparse warpgroup MMA (wgmma.mma_async.sp): a 64x160 S32
// accumulator tile is updated with an S8 x U8 product. "RS" variant: the A
// fragment is held in registers (a00..a03) while B is referenced through a
// shared-memory matrix descriptor (desc_b). `e` carries the sparsity
// metadata word; `spsel` is the sparse-selector immediate baked into the
// instruction ("n" constraint). scale_D == Zero makes the instruction
// overwrite D instead of accumulating (communicated via predicate `p`).
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x160x64_S32S8U8_RS_TN
{
  using DRegisters = void;          // D aliases C: accumulation happens in place
  using ARegisters = uint32_t[4];   // packed S8 A fragment (per thread)
  using ERegisters = uint32_t[1];   // sparsity metadata word
  using BRegisters = uint64_t[1];   // shared-memory descriptor for B
  using CRegisters = uint32_t[80];  // S32 accumulator fragment (per thread)

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
      uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
      uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
      uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
      uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
      uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
      uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    // Operand map: %0-%79 = accumulator, %80-%83 = A, %84 = B descriptor,
    // %85 = metadata, %86 = spsel immediate, %87 = scale_D (-> predicate p).
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %87, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.u8 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79},"
      "{%80, %81, %82, %83},"
      " %84,"
      " %85, %86,"
      " p;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
        "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
        "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
        "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
        "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
        "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
        "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79)
      :  "r"(a00),  "r"(a01),  "r"(a02),  "r"(a03),
         "l"(desc_b),
         "r"(e), "n"(int32_t(spsel)),
         "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};
////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x160x64 TN S32+=S8*U8
//
// Saturating variant of GMMA_64x160x64_S32S8U8_RS_TN: same operands and
// register layout, but the ".satfinite" opcode suffix clamps the S32 result
// to the representable range instead of wrapping on overflow. A is in
// registers, B is addressed via a shared-memory descriptor, `e` is the
// sparsity metadata and `spsel` the sparse-selector immediate.
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x160x64_S32S8U8_RS_TN_SATURATE
{
  using DRegisters = void;          // D aliases C: accumulation happens in place
  using ARegisters = uint32_t[4];   // packed S8 A fragment (per thread)
  using ERegisters = uint32_t[1];   // sparsity metadata word
  using BRegisters = uint64_t[1];   // shared-memory descriptor for B
  using CRegisters = uint32_t[80];  // S32 accumulator fragment (per thread)

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
      uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
      uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
      uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
      uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
      uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
      uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    // Operand map: %0-%79 = accumulator, %80-%83 = A, %84 = B descriptor,
    // %85 = metadata, %86 = spsel immediate, %87 = scale_D (-> predicate p).
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %87, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.u8.satfinite "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79},"
      "{%80, %81, %82, %83},"
      " %84,"
      " %85, %86,"
      " p;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
        "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
        "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
        "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
        "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
        "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
        "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79)
      :  "r"(a00),  "r"(a01),  "r"(a02),  "r"(a03),
         "l"(desc_b),
         "r"(e), "n"(int32_t(spsel)),
         "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};
////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x176x64 TN S32+=S8*U8
//
// 2:4 structured-sparse warpgroup MMA (wgmma.mma_async.sp): a 64x176 S32
// accumulator tile is updated with an S8 x U8 product. "RS" variant: the A
// fragment is held in registers (a00..a03) while B is referenced through a
// shared-memory matrix descriptor (desc_b). `e` carries the sparsity
// metadata word; `spsel` is the sparse-selector immediate. scale_D == Zero
// makes the instruction overwrite D instead of accumulating (predicate `p`).
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x176x64_S32S8U8_RS_TN
{
  using DRegisters = void;          // D aliases C: accumulation happens in place
  using ARegisters = uint32_t[4];   // packed S8 A fragment (per thread)
  using ERegisters = uint32_t[1];   // sparsity metadata word
  using BRegisters = uint64_t[1];   // shared-memory descriptor for B
  using CRegisters = uint32_t[88];  // S32 accumulator fragment (per thread)

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
      uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
      uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
      uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
      uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
      uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
      uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
      uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83,
      uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    // Operand map: %0-%87 = accumulator, %88-%91 = A, %92 = B descriptor,
    // %93 = metadata, %94 = spsel immediate, %95 = scale_D (-> predicate p).
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %95, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.u8 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87},"
      "{%88, %89, %90, %91},"
      " %92,"
      " %93, %94,"
      " p;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
        "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
        "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
        "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
        "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
        "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
        "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79),
        "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83),
        "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87)
      :  "r"(a00),  "r"(a01),  "r"(a02),  "r"(a03),
         "l"(desc_b),
         "r"(e), "n"(int32_t(spsel)),
         "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};
////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x176x64 TN S32+=S8*U8
//
// Saturating variant of GMMA_64x176x64_S32S8U8_RS_TN: same operands and
// register layout, but the ".satfinite" opcode suffix clamps the S32 result
// to the representable range instead of wrapping on overflow. A is in
// registers, B is addressed via a shared-memory descriptor, `e` is the
// sparsity metadata and `spsel` the sparse-selector immediate.
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x176x64_S32S8U8_RS_TN_SATURATE
{
  using DRegisters = void;          // D aliases C: accumulation happens in place
  using ARegisters = uint32_t[4];   // packed S8 A fragment (per thread)
  using ERegisters = uint32_t[1];   // sparsity metadata word
  using BRegisters = uint64_t[1];   // shared-memory descriptor for B
  using CRegisters = uint32_t[88];  // S32 accumulator fragment (per thread)

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
      uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
      uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
      uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
      uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
      uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
      uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
      uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83,
      uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    // Operand map: %0-%87 = accumulator, %88-%91 = A, %92 = B descriptor,
    // %93 = metadata, %94 = spsel immediate, %95 = scale_D (-> predicate p).
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %95, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.u8.satfinite "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87},"
      "{%88, %89, %90, %91},"
      " %92,"
      " %93, %94,"
      " p;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
        "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
        "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
        "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
        "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
        "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
        "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79),
        "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83),
        "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87)
      :  "r"(a00),  "r"(a01),  "r"(a02),  "r"(a03),
         "l"(desc_b),
         "r"(e), "n"(int32_t(spsel)),
         "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};
////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x208x64 TN S32+=S8*U8
//
// 2:4 structured-sparse warpgroup MMA (wgmma.mma_async.sp): a 64x208 S32
// accumulator tile is updated with an S8 x U8 product. "RS" variant: the A
// fragment is held in registers (a000..a003) while B is referenced through a
// shared-memory matrix descriptor (desc_b). `e` carries the sparsity
// metadata word; `spsel` is the sparse-selector immediate. scale_D == Zero
// makes the instruction overwrite D instead of accumulating (predicate `p`).
// NOTE: operand names are zero-padded to three digits here because the
// accumulator count exceeds 100.
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x208x64_S32S8U8_RS_TN
{
  using DRegisters = void;           // D aliases C: accumulation happens in place
  using ARegisters = uint32_t[4];    // packed S8 A fragment (per thread)
  using ERegisters = uint32_t[1];    // sparsity metadata word
  using BRegisters = uint64_t[1];    // shared-memory descriptor for B
  using CRegisters = uint32_t[104];  // S32 accumulator fragment (per thread)

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
      uint64_t const& desc_b,
      uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
      uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
      uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
      uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
      uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
      uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
      uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
      uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
      uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
      uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
      uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
      uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
      uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
      uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
      uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
      uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
      uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
      uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
      uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
      uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
      uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
      uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
      uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
      uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
      uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
      uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    // Operand map: %0-%103 = accumulator, %104-%107 = A, %108 = B descriptor,
    // %109 = metadata, %110 = spsel immediate, %111 = scale_D (-> predicate p).
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %111, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.u8 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103},"
      "{%104, %105, %106, %107},"
      " %108,"
      " %109, %110,"
      " p;\n"
    "}\n"
      : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
        "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
        "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
        "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
        "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
        "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
        "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
        "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
        "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
        "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
        "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
        "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
        "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
        "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
        "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
        "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
        "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
        "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
        "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
        "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
        "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
        "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
        "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
        "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
        "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
        "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103)
      :  "r"(a000),  "r"(a001),  "r"(a002),  "r"(a003),
         "l"(desc_b),
         "r"(e), "n"(int32_t(spsel)),
         "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};
////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x208x64 TN S32+=S8*U8
//
// Saturating variant of GMMA_64x208x64_S32S8U8_RS_TN: same operands and
// register layout, but the ".satfinite" opcode suffix clamps the S32 result
// to the representable range instead of wrapping on overflow. A is in
// registers, B is addressed via a shared-memory descriptor, `e` is the
// sparsity metadata and `spsel` the sparse-selector immediate.
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x208x64_S32S8U8_RS_TN_SATURATE
{
  using DRegisters = void;           // D aliases C: accumulation happens in place
  using ARegisters = uint32_t[4];    // packed S8 A fragment (per thread)
  using ERegisters = uint32_t[1];    // sparsity metadata word
  using BRegisters = uint64_t[1];    // shared-memory descriptor for B
  using CRegisters = uint32_t[104];  // S32 accumulator fragment (per thread)

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
      uint64_t const& desc_b,
      uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
      uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
      uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
      uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
      uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
      uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
      uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
      uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
      uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
      uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
      uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
      uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
      uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
      uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
      uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
      uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
      uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
      uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
      uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
      uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
      uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
      uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
      uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
      uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
      uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
      uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    // Operand map: %0-%103 = accumulator, %104-%107 = A, %108 = B descriptor,
    // %109 = metadata, %110 = spsel immediate, %111 = scale_D (-> predicate p).
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %111, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.u8.satfinite "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103},"
      "{%104, %105, %106, %107},"
      " %108,"
      " %109, %110,"
      " p;\n"
    "}\n"
      : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
        "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
        "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
        "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
        "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
        "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
        "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
        "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
        "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
        "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
        "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
        "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
        "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
        "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
        "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
        "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
        "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
        "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
        "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
        "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
        "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
        "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
        "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
        "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
        "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
        "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103)
      :  "r"(a000),  "r"(a001),  "r"(a002),  "r"(a003),
         "l"(desc_b),
         "r"(e), "n"(int32_t(spsel)),
         "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};
////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x224x64 TN S32+=S8*U8
//
// 2:4 structured-sparse warpgroup MMA (wgmma.mma_async.sp): a 64x224 S32
// accumulator tile is updated with an S8 x U8 product. "RS" variant: the A
// fragment is held in registers (a000..a003) while B is referenced through a
// shared-memory matrix descriptor (desc_b). `e` carries the sparsity
// metadata word; `spsel` is the sparse-selector immediate. scale_D == Zero
// makes the instruction overwrite D instead of accumulating (predicate `p`).
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x224x64_S32S8U8_RS_TN
{
  using DRegisters = void;           // D aliases C: accumulation happens in place
  using ARegisters = uint32_t[4];    // packed S8 A fragment (per thread)
  using ERegisters = uint32_t[1];    // sparsity metadata word
  using BRegisters = uint64_t[1];    // shared-memory descriptor for B
  using CRegisters = uint32_t[112];  // S32 accumulator fragment (per thread)

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
      uint64_t const& desc_b,
      uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
      uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
      uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
      uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
      uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
      uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
      uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
      uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
      uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
      uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
      uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
      uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
      uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
      uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
      uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
      uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
      uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
      uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
      uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
      uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
      uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
      uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
      uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
      uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
      uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
      uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
      uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107,
      uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    // Operand map: %0-%111 = accumulator, %112-%115 = A, %116 = B descriptor,
    // %117 = metadata, %118 = spsel immediate, %119 = scale_D (-> predicate p).
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %119, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.u8 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103, "
      " %104, %105, %106, %107, %108, %109, %110, %111},"
      "{%112, %113, %114, %115},"
      " %116,"
      " %117, %118,"
      " p;\n"
    "}\n"
      : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
        "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
        "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
        "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
        "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
        "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
        "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
        "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
        "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
        "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
        "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
        "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
        "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
        "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
        "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
        "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
        "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
        "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
        "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
        "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
        "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
        "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
        "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
        "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
        "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
        "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103),
        "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107),
        "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111)
      :  "r"(a000),  "r"(a001),  "r"(a002),  "r"(a003),
         "l"(desc_b),
         "r"(e), "n"(int32_t(spsel)),
         "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};
////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x224x64 TN S32+=S8*U8
//
// Saturating variant of GMMA_64x224x64_S32S8U8_RS_TN: same operands and
// register layout, but the ".satfinite" opcode suffix clamps the S32 result
// to the representable range instead of wrapping on overflow. A is in
// registers, B is addressed via a shared-memory descriptor, `e` is the
// sparsity metadata and `spsel` the sparse-selector immediate.
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x224x64_S32S8U8_RS_TN_SATURATE
{
  using DRegisters = void;           // D aliases C: accumulation happens in place
  using ARegisters = uint32_t[4];    // packed S8 A fragment (per thread)
  using ERegisters = uint32_t[1];    // sparsity metadata word
  using BRegisters = uint64_t[1];    // shared-memory descriptor for B
  using CRegisters = uint32_t[112];  // S32 accumulator fragment (per thread)

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
      uint64_t const& desc_b,
      uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
      uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
      uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
      uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
      uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
      uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
      uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
      uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
      uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
      uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
      uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
      uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
      uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
      uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
      uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
      uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
      uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
      uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
      uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
      uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
      uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
      uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
      uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
      uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
      uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
      uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
      uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107,
      uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    // Operand map: %0-%111 = accumulator, %112-%115 = A, %116 = B descriptor,
    // %117 = metadata, %118 = spsel immediate, %119 = scale_D (-> predicate p).
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %119, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.u8.satfinite "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103, "
      " %104, %105, %106, %107, %108, %109, %110, %111},"
      "{%112, %113, %114, %115},"
      " %116,"
      " %117, %118,"
      " p;\n"
    "}\n"
      : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
        "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
        "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
        "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
        "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
        "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
        "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
        "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
        "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
        "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
        "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
        "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
        "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
        "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
        "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
        "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
        "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
        "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
        "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
        "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
        "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
        "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
        "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
        "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
        "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
        "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103),
        "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107),
        "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111)
      :  "r"(a000),  "r"(a001),  "r"(a002),  "r"(a003),
         "l"(desc_b),
         "r"(e), "n"(int32_t(spsel)),
         "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};
uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & 
d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + 
"+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, 
uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), 
"+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), 
"+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=U8*S8 
+template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x48x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8S8_SS_TN +{ + using 
DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + 
"r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : 
"+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t 
& d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = 
void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + 
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & 
d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), 
"+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, 
uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
+// SPARSE GMMA 64x160x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8S8_SS_TN_SATURATE +{ + using 
DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, 
%18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + 
+ CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " 
%24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + 
using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, 
%4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, 
+ uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), 
"+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & 
d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : 
"+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void 
+ fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, 
uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + 
"+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, 
uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, 
%87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 
64x240x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + 
uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), 
"+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t 
& d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + 
"+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p;\n" 
+ "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), 
"n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " 
%44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & 
d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + 
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), 
"+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using 
CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " 
%56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, 
uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), 
"+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, 
uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), 
"+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t 
& d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), 
"+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & 
d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), 
"+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & 
d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + 
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, 
uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, 
%17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); 
+#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + 
uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), 
"+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + 
uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.s8 " + 
"{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), 
"+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, 
uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : 
"+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + 
using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & 
d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), 
"+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t 
& d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), 
"+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + 
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 
SPARSE GMMA 64x48x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero
+>
+struct GMMA_64x48x64_S32U8U8_SS_TN_SATURATE
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using ERegisters = uint32_t[1];    // sparse-metadata operand (the "e" argument below)
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[24];   // 24 x 32-bit s32 accumulators (64x48 output tile)
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
+      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
+      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
+      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
+      uint32_t const& e,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %28, 0;\n"   // p = (scale_D != 0); %28 is the scale_D input (operand #24+4)
+      "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.u8.satfinite "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23},"
+      " %24,"
+      " %25,"
+      " %26, %27,"
+      " p;\n"
+    "}\n"
+      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
+        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
+        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
+        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23)
+      :  "l"(desc_a),
+         "l"(desc_b),
+         "r"(e), "n"(int32_t(spsel)),   // spsel is folded in as a compile-time immediate ("n")
+         "r"(int32_t(scale_D)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x80x64 TN S32+=U8*U8
+template <
+  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x80x64_S32U8U8_SS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using ERegisters = uint32_t[1];    // sparse-metadata operand (the "e" argument below)
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[40];   // 40 x 32-bit s32 accumulators (64x80 output tile)
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
+      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
+      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
+      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
+      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
+      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
+      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
+      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
+      uint32_t const& e,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %44, 0;\n"   // p = (scale_D != 0); %44 is the scale_D input (operand #40+4)
+      "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.u8 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27, %28, %29, %30, %31, "
+      " %32, %33, %34, %35, %36, %37, %38, %39},"
+      " %40,"
+      " %41,"
+      " %42, %43,"
+      " p;\n"
+    "}\n"
+      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
+        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
+        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
+        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
+        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
+        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
+        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
+        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39)
+      :  "l"(desc_a),
+         "l"(desc_b),
+         "r"(e), "n"(int32_t(spsel)),
+         "r"(int32_t(scale_D)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x80x64 TN S32+=U8*U8
+template <
+  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x80x64_S32U8U8_SS_TN_SATURATE
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using ERegisters = uint32_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[40];   // 40 x 32-bit s32 accumulators (64x80 output tile)
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
+      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
+      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
+      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
+      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
+      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
+      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
+      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
+      uint32_t const& e,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %44, 0;\n"   // p = (scale_D != 0); %44 is the scale_D input (operand #40+4)
+      "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.u8.satfinite "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27, %28, %29, %30, %31, "
+      " %32, %33, %34, %35, %36, %37, %38, %39},"
+      " %40,"
+      " %41,"
+      " %42, %43,"
+      " p;\n"
+    "}\n"
+      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
+        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
+        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
+        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
+        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
+        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
+        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
+        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39)
+      :  "l"(desc_a),
+         "l"(desc_b),
+         "r"(e), "n"(int32_t(spsel)),
+         "r"(int32_t(scale_D)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x112x64 TN S32+=U8*U8
+template <
+  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x112x64_S32U8U8_SS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using ERegisters = uint32_t[1];    // sparse-metadata operand (the "e" argument below)
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[56];   // 56 x 32-bit s32 accumulators (64x112 output tile)
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
+      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
+      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
+      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
+      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
+      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
+      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
+      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
+      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
+      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
+      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
+      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
+      uint32_t const& e,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %60, 0;\n"   // p = (scale_D != 0); %60 is the scale_D input (operand #56+4)
+      "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.u8 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27, %28, %29, %30, %31, "
+      " %32, %33, %34, %35, %36, %37, %38, %39, "
+      " %40, %41, %42, %43, %44, %45, %46, %47, "
+      " %48, %49, %50, %51, %52, %53, %54, %55},"
+      " %56,"
+      " %57,"
+      " %58, %59,"
+      " p;\n"
+    "}\n"
+      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
+        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
+        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
+        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
+        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
+        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
+        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
+        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
+        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
+        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
+        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
+        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55)
+      :  "l"(desc_a),
+         "l"(desc_b),
+         "r"(e), "n"(int32_t(spsel)),
+         "r"(int32_t(scale_D)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x112x64 TN S32+=U8*U8
+template <
+  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x112x64_S32U8U8_SS_TN_SATURATE
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using ERegisters = uint32_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[56];   // 56 x 32-bit s32 accumulators (64x112 output tile)
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
+      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
+      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
+      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
+      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
+      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
+      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
+      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
+      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
+      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
+      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
+      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
+      uint32_t const& e,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %60, 0;\n"   // p = (scale_D != 0); %60 is the scale_D input (operand #56+4)
+      "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.u8.satfinite "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27, %28, %29, %30, %31, "
+      " %32, %33, %34, %35, %36, %37, %38, %39, "
+      " %40, %41, %42, %43, %44, %45, %46, %47, "
+      " %48, %49, %50, %51, %52, %53, %54, %55},"
+      " %56,"
+      " %57,"
+      " %58, %59,"
+      " p;\n"
+    "}\n"
+      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
+        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
+        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
+        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
+        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
+        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
+        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
+        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
+        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
+        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
+        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
+        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55)
+      :  "l"(desc_a),
+         "l"(desc_b),
+         "r"(e), "n"(int32_t(spsel)),
+         "r"(int32_t(scale_D)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x144x64 TN S32+=U8*U8
+template <
+  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x144x64_S32U8U8_SS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using ERegisters = uint32_t[1];    // sparse-metadata operand (the "e" argument below)
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[72];   // 72 x 32-bit s32 accumulators (64x144 output tile)
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
+      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
+      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
+      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
+      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
+      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
+      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
+      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
+
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
+      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
+      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
+      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
+      uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
+      uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
+      uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
+      uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
+      uint32_t const& e,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %76, 0;\n"   // p = (scale_D != 0); %76 is the scale_D input (operand #72+4)
+      "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.u8 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27, %28, %29, %30, %31, "
+      " %32, %33, %34, %35, %36, %37, %38, %39, "
+      " %40, %41, %42, %43, %44, %45, %46, %47, "
+      " %48, %49, %50, %51, %52, %53, %54, %55, "
+      " %56, %57, %58, %59, %60, %61, %62, %63, "
+      " %64, %65, %66, %67, %68, %69, %70, %71},"
+      " %72,"
+      " %73,"
+      " %74, %75,"
+      " p;\n"
+    "}\n"
+      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
+        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
+        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
+        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
+        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
+        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
+        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
+        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
+        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
+        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
+        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
+        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
+        "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
+        "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
+        "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
+        "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71)
+      :  "l"(desc_a),
+         "l"(desc_b),
+         "r"(e), "n"(int32_t(spsel)),
+         "r"(int32_t(scale_D)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x144x64 TN S32+=U8*U8
+template <
+  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x144x64_S32U8U8_SS_TN_SATURATE
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using ERegisters = uint32_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[72];   // 72 x 32-bit s32 accumulators (64x144 output tile)
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
+      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
+      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
+      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
+      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
+      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
+      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
+      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
+      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
+      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
+      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
+      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
+      uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
+      uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
+      uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
+      uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
+      uint32_t const& e,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %76, 0;\n"   // p = (scale_D != 0); %76 is the scale_D input (operand #72+4)
+      "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.u8.satfinite "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27, %28, %29, %30, %31, "
+      " %32, %33, %34, %35, %36, %37, %38, %39, "
+      " %40, %41, %42, %43, %44, %45, %46, %47, "
+      " %48, %49, %50, %51, %52, %53, %54, %55, "
+      " %56, %57, %58, %59, %60, %61, %62, %63, "
+      " %64, %65, %66, %67, %68, %69, %70, %71},"
+      " %72,"
+      " %73,"
+      " %74, %75,"
+      " p;\n"
+    "}\n"
+      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
+        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
+        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
+        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
+        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
+        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
+        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
+        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
+        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
+        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
+        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
+        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
+        "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
+        "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
+        "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
+        "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71)
+      :  "l"(desc_a),
+         "l"(desc_b),
+         "r"(e), "n"(int32_t(spsel)),
+         "r"(int32_t(scale_D)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x160x64 TN S32+=U8*U8
+template <
+  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x160x64_S32U8U8_SS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using ERegisters = uint32_t[1];    // sparse-metadata operand (the "e" argument below)
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[80];   // 80 x 32-bit s32 accumulators (64x160 output tile)
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
+      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
+      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
+      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
+      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
+      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
+      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
+      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
+      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
+      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
+      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
+      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
+      uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
+      uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
+      uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
+      uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
+      uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
+      uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
+      uint32_t const& e,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %84, 0;\n"   // p = (scale_D != 0); %84 is the scale_D input (operand #80+4)
+      "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.u8 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27, %28, %29, %30, %31, "
+      " %32, %33, %34, %35, %36, %37, %38, %39, "
+      " %40, %41, %42, %43, %44, %45, %46, %47, "
+      " %48, %49, %50, %51, %52, %53, %54, %55, "
+      " %56, %57, %58, %59, %60, %61, %62, %63, "
+      " %64, %65, %66, %67, %68, %69, %70, %71, "
+      " %72, %73, %74, %75, %76, %77, %78, %79},"
+      " %80,"
+      " %81,"
+      " %82, %83,"
+      " p;\n"
+    "}\n"
+      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
+        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
+        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
+        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
+        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
+        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
+        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
+        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
+        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
+        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
+        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
+        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
+        "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
+        "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
+        "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
+        "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
+        "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
+        "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79)
+      :  "l"(desc_a),
+         "l"(desc_b),
+         "r"(e), "n"(int32_t(spsel)),
+         "r"(int32_t(scale_D)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x160x64 TN S32+=U8*U8
+template <
+  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x160x64_S32U8U8_SS_TN_SATURATE
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using ERegisters = uint32_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[80];   // 80 x 32-bit s32 accumulators (64x160 output tile)
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
+      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
+      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
+      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
+      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
+      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
+      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
+      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
+      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
+      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
+      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
+      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
+      uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
+      uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
+      uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
+      uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
+      uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
+      uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
+      uint32_t const& e,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %84, 0;\n"   // p = (scale_D != 0); %84 is the scale_D input (operand #80+4)
+      "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.u8.satfinite "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27, %28, %29, %30, %31, "
+      " %32, %33, %34, %35, %36, %37, %38, %39, "
+      " %40, %41, %42, %43, %44, %45, %46, %47, "
+      " %48, %49, %50, %51, %52, %53, %54, %55, "
+      " %56, %57, %58, %59, %60, %61, %62, %63, "
+      " %64, %65, %66, %67, %68, %69, %70, %71, "
+      " %72, %73, %74, %75, %76, %77, %78, %79},"
+      " %80,"
+      " %81,"
+      " %82, %83,"
+      " p;\n"
+    "}\n"
+      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
+        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
+        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
+        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
+        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
+        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
+        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
+        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
+        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
+        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
+        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
+        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
+        "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
+        "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
+        "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
+        "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
+        "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
+        "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79)
+      :  "l"(desc_a),
+         "l"(desc_b),
+         "r"(e), "n"(int32_t(spsel)),
+         "r"(int32_t(scale_D)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x176x64 TN S32+=U8*U8
+template <
+  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x176x64_S32U8U8_SS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using
ERegisters = uint32_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[88];   // 88 x 32-bit s32 accumulators (64x176 output tile)
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
+      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
+      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
+      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
+      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
+      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
+      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
+      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
+      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
+      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
+      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
+      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
+      uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
+      uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
+      uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
+      uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
+      uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
+      uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
+      uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83,
+      uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87,
+      uint32_t const& e,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %92, 0;\n"   // p = (scale_D != 0); %92 is the scale_D input (operand #88+4)
+      "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.u8 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27, %28, %29, %30, %31, "
+      " %32, %33, %34, %35, %36, %37, %38, %39, "
+      " %40, %41, %42, %43, %44, %45, %46, %47, "
+      " %48, %49, %50, %51, %52, %53, %54, %55, "
+      " %56, %57, %58, %59, %60, %61, %62, %63, "
+      " %64, %65, %66, %67, %68, %69, %70, %71, "
+      " %72, %73, %74, %75, %76, %77, %78, %79, "
+      " %80, %81, %82, %83, %84, %85, %86, %87},"
+      " %88,"
+      " %89,"
+      " %90, %91,"
+      " p;\n"
+    "}\n"
+      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
+        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
+        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
+        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
+        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
+        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
+        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
+        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
+        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
+        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
+        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
+        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
+        "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
+        "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
+        "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
+        "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
+        "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
+        "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79),
+        "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83),
+        "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87)
+      :  "l"(desc_a),
+         "l"(desc_b),
+         "r"(e), "n"(int32_t(spsel)),
+         "r"(int32_t(scale_D)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x176x64 TN S32+=U8*U8
+template <
+  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x176x64_S32U8U8_SS_TN_SATURATE
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using ERegisters = uint32_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[88];   // 88 x 32-bit s32 accumulators (64x176 output tile)
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
+      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
+      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
+      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
+      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
+      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
+      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
+      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
+      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
+      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
+      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
+      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
+      uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
+      uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
+      uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
+      uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
+      uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
+      uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
+      uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83,
+      uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87,
+      uint32_t const& e,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %92, 0;\n"   // p = (scale_D != 0); %92 is the scale_D input (operand #88+4)
+      "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.u8.satfinite "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13, %14, %15, "
+      " %16, %17, %18, %19, %20, %21, %22, %23, "
+      " %24, %25, %26, %27, %28, %29, %30, %31, "
+      " %32, %33, %34, %35, %36, %37, %38, %39, "
+      " %40, %41, %42, %43, %44, %45, %46, %47, "
+      " %48, %49, %50, %51, %52, %53, %54, %55, "
+      " %56, %57, %58, %59, %60, %61, %62, %63, "
+      " %64, %65, %66, %67, %68, %69, %70, %71, "
+      " %72, %73, %74, %75, %76, %77, %78, %79, "
+      " %80, %81, %82, %83, %84, %85, %86, %87},"
+      " %88,"
+      " %89,"
+      " %90, %91,"
+      " p;\n"
+    "}\n"
+      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
+        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
+        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
+        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
+        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
+        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
+        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
+        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
+        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
+        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
+        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
+        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
+        "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
+        "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
+        "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
+        "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
+        "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
+        "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79),
+        "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83),
+        "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87)
+      :  "l"(desc_a),
+         "l"(desc_b),
+         "r"(e), "n"(int32_t(spsel)),
+         "r"(int32_t(scale_D)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x208x64 TN S32+=U8*U8
// Warpgroup-wide sparse MMA (wgmma.mma_async.sp): M=64, N=208, K=64 tile with
// s32 accumulators and u8 x u8 operands. "SS": both A and B are read from
// shared memory through 64-bit matrix descriptors. `e` is the 32-bit sparsity
// metadata word and `spsel` is the PTX sparse-selector immediate baked into
// the instruction. `scale_D` feeds predicate `p` (the instruction's scale-d
// input; see the PTX ISA wgmma.mma_async.sp documentation for its semantics).
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x208x64_S32U8U8_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];   // shared-memory matrix descriptor for A
  using ERegisters = uint32_t[1];   // sparsity metadata
  using BRegisters = uint64_t[1];   // shared-memory matrix descriptor for B
  using CRegisters = uint32_t[104]; // per-thread s32 accumulator fragment

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
      uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
      uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
      uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
      uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
      uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
      uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
      uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
      uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
      uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
      uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
      uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
      uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
      uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
      uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
      uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
      uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
      uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
      uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
      uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
      uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
      uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
      uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
      uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
      uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
      uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    // Operand map: %0-%103 accumulators (in/out), %104 desc_a, %105 desc_b,
    // %106 metadata e, %107 spsel immediate, %108 scale_D (drives p).
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %108, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.u8 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103},"
      " %104,"
      " %105,"
      " %106, %107,"
      " p;\n"
    "}\n"
      : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
        "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
        "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
        "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
        "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
        "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
        "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
        "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
        "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
        "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
        "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
        "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
        "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
        "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
        "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
        "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
        "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
        "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
        "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
        "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
        "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
        "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
        "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
        "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
        "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
        "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103)
      :  "l"(desc_a),
         "l"(desc_b),
         "r"(e), "n"(int32_t(spsel)),
         "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x208x64 TN S32+=U8*U8
// Identical to GMMA_64x208x64_S32U8U8_SS_TN, except the instruction carries
// the .satfinite qualifier (saturating s32 accumulation).
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x208x64_S32U8U8_SS_TN_SATURATE
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];   // shared-memory matrix descriptor for A
  using ERegisters = uint32_t[1];   // sparsity metadata
  using BRegisters = uint64_t[1];   // shared-memory matrix descriptor for B
  using CRegisters = uint32_t[104]; // per-thread s32 accumulator fragment

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
      uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
      uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
      uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
      uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
      uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
      uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
      uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
      uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
      uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
      uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
      uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
      uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
      uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
      uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
      uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
      uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
      uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
      uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
      uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
      uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
      uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
      uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
      uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
      uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
      uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %108, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.u8.satfinite "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103},"
      " %104,"
      " %105,"
      " %106, %107,"
      " p;\n"
    "}\n"
      : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
        "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
        "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
        "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
        "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
        "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
        "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
        "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
        "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
        "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
        "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
        "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
        "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
        "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
        "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
        "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
        "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
        "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
        "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
        "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
        "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
        "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
        "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
        "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
        "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
        "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103)
      :  "l"(desc_a),
         "l"(desc_b),
         "r"(e), "n"(int32_t(spsel)),
         "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x224x64 TN S32+=U8*U8
// Same family as the 64x208 atom above, widened to N=224 (112 accumulator
// registers; descriptors at %112/%113, metadata/spsel at %114/%115,
// scale_D predicate input at %116).
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x224x64_S32U8U8_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];   // shared-memory matrix descriptor for A
  using ERegisters = uint32_t[1];   // sparsity metadata
  using BRegisters = uint64_t[1];   // shared-memory matrix descriptor for B
  using CRegisters = uint32_t[112]; // per-thread s32 accumulator fragment

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
      uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
      uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
      uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
      uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
      uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
      uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
      uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
      uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
      uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
      uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
      uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
      uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
      uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
      uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
      uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
      uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
      uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
      uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
      uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
      uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
      uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
      uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
      uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
      uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
      uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
      uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107,
      uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %116, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.u8 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103, "
      " %104, %105, %106, %107, %108, %109, %110, %111},"
      " %112,"
      " %113,"
      " %114, %115,"
      " p;\n"
    "}\n"
      : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
        "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
        "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
        "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
        "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
        "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
        "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
        "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
        "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
        "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
        "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
        "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
        "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
        "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
        "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
        "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
        "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
        "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
        "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
        "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
        "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
        "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
        "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
        "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
        "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
        "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103),
        "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107),
        "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111)
      :  "l"(desc_a),
         "l"(desc_b),
         "r"(e), "n"(int32_t(spsel)),
         "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x224x64 TN S32+=U8*U8
// Identical to GMMA_64x224x64_S32U8U8_SS_TN, except the instruction carries
// the .satfinite qualifier (saturating s32 accumulation).
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x224x64_S32U8U8_SS_TN_SATURATE
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];   // shared-memory matrix descriptor for A
  using ERegisters = uint32_t[1];   // sparsity metadata
  using BRegisters = uint64_t[1];   // shared-memory matrix descriptor for B
  using CRegisters = uint32_t[112]; // per-thread s32 accumulator fragment

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
      uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
      uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
      uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
      uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
      uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
      uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
      uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
      uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
      uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
      uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
      uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
      uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
      uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
      uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
      uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
      uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
      uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
      uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
      uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
      uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
      uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
      uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
      uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
      uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
      uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
      uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107,
      uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %116, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.u8.satfinite "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103, "
      " %104, %105, %106, %107, %108, %109, %110, %111},"
      " %112,"
      " %113,"
      " %114, %115,"
      " p;\n"
    "}\n"
      : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
        "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
        "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
        "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
        "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
        "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
        "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
        "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
        "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
        "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
        "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
        "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
        "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
        "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
        "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
        "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
        "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
        "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
        "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
        "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
        "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
        "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
        "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
        "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
        "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
        "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103),
        "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107),
        "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111)
      :  "l"(desc_a),
         "l"(desc_b),
         "r"(e), "n"(int32_t(spsel)),
         "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x240x64 TN S32+=U8*U8
// Same family, widened to N=240 (120 accumulator registers; descriptors at
// %120/%121, metadata/spsel at %122/%123, scale_D predicate input at %124).
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x240x64_S32U8U8_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];   // shared-memory matrix descriptor for A
  using ERegisters = uint32_t[1];   // sparsity metadata
  using BRegisters = uint64_t[1];   // shared-memory matrix descriptor for B
  using CRegisters = uint32_t[120]; // per-thread s32 accumulator fragment

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
      uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
      uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
      uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
      uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
      uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
      uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
      uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
      uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
      uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
      uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
      uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
      uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
      uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
      uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
      uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
      uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
      uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
      uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
      uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
      uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
      uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
      uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
      uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
      uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
      uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
      uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107,
      uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111,
      uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115,
      uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %124, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.u8 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103, "
      " %104, %105, %106, %107, %108, %109, %110, %111, "
      " %112, %113, %114, %115, %116, %117, %118, %119},"
      " %120,"
      " %121,"
      " %122, %123,"
      " p;\n"
    "}\n"
      : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
        "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
        "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
        "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
        "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
        "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
        "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
        "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
        "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
        "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
        "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
        "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
        "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
        "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
        "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
        "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
        "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
        "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
        "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
        "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
        "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
        "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
        "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
        "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
        "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
        "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103),
        "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107),
        "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111),
        "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115),
        "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119)
      :  "l"(desc_a),
         "l"(desc_b),
         "r"(e), "n"(int32_t(spsel)),
         "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x240x64 TN S32+=U8*U8
// Identical to GMMA_64x240x64_S32U8U8_SS_TN, except the instruction carries
// the .satfinite qualifier (saturating s32 accumulation).
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x240x64_S32U8U8_SS_TN_SATURATE
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];   // shared-memory matrix descriptor for A
  using ERegisters = uint32_t[1];   // sparsity metadata
  using BRegisters = uint64_t[1];   // shared-memory matrix descriptor for B
  using CRegisters = uint32_t[120]; // per-thread s32 accumulator fragment

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
      uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
      uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
      uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
      uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
      uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
      uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
      uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
      uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
      uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
      uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
      uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
      uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
      uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
      uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
      uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
      uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
      uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
      uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
      uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
      uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
      uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
      uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
      uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
      uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
      uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
      uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107,
      uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111,
      uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115,
      uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %124, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.u8.satfinite "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103, "
      " %104, %105, %106, %107, %108, %109, %110, %111, "
      " %112, %113, %114, %115, %116, %117, %118, %119},"
      " %120,"
      " %121,"
      " %122, %123,"
      " p;\n"
    "}\n"
      : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
        "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
        "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
        "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
        "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
        "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
        "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
        "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
        "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
        "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
        "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
        "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
        "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
        "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
        "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
        "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
        "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
        "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
        "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
        "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
        "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
        "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
        "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
        "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
        "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
        "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103),
        "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107),
        "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111),
        "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115),
        "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119)
      :  "l"(desc_a),
         "l"(desc_b),
         "r"(e), "n"(int32_t(spsel)),
         "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x24x64 TN S32+=U8*U8
// Narrow N=24 variant. "RS": the A fragment is supplied from registers
// (a00..a03) while B is read from shared memory through a 64-bit matrix
// descriptor. Operand map: %0-%11 accumulators, %12-%15 A, %16 desc_b,
// %17 metadata e, %18 spsel immediate, %19 scale_D (drives p).
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x24x64_S32U8U8_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];  // A fragment held in registers
  using ERegisters = uint32_t[1];  // sparsity metadata
  using BRegisters = uint64_t[1];  // shared-memory matrix descriptor for B
  using CRegisters = uint32_t[12]; // per-thread s32 accumulator fragment

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %19, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.u8.u8 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11},"
      "{%12, %13, %14, %15},"
      " %16,"
      " %17, %18,"
      " p;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11)
      :  "r"(a00), "r"(a01), "r"(a02), "r"(a03),
         "l"(desc_b),
         "r"(e), "n"(int32_t(spsel)),
         "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x24x64 TN S32+=U8*U8
// Identical to GMMA_64x24x64_S32U8U8_RS_TN, except the instruction carries
// the .satfinite qualifier (saturating s32 accumulation).
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x24x64_S32U8U8_RS_TN_SATURATE
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];  // A fragment held in registers
  using ERegisters = uint32_t[1];  // sparsity metadata
  using BRegisters = uint64_t[1];  // shared-memory matrix descriptor for B
  using CRegisters = uint32_t[12]; // per-thread s32 accumulator fragment

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %19, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.u8.u8.satfinite "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11},"
      "{%12, %13, %14, %15},"
      " %16,"
      " %17, %18,"
      " p;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11)
      :  "r"(a00), "r"(a01), "r"(a02), "r"(a03),
         "l"(desc_b),
         "r"(e), "n"(int32_t(spsel)),
         "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x48x64 TN S32+=U8*U8
// RS variant at N=48 (24 accumulator registers; A at %24-%27, desc_b at %28,
// metadata/spsel at %29/%30, scale_D predicate input at %31).
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x48x64_S32U8U8_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];  // A fragment held in registers
  using ERegisters = uint32_t[1];  // sparsity metadata
  using BRegisters = uint64_t[1];  // shared-memory matrix descriptor for B
  using CRegisters = uint32_t[24]; // per-thread s32 accumulator fragment

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %31, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.u8 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23},"
      "{%24, %25, %26, %27},"
      " %28,"
      " %29, %30,"
      " p;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23)
      :  "r"(a00), "r"(a01), "r"(a02), "r"(a03),
         "l"(desc_b),
         "r"(e), "n"(int32_t(spsel)),
         "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x48x64 TN S32+=U8*U8
// Identical to GMMA_64x48x64_S32U8U8_RS_TN, except the instruction carries
// the .satfinite qualifier (saturating s32 accumulation).
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x48x64_S32U8U8_RS_TN_SATURATE
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];  // A fragment held in registers
  using ERegisters = uint32_t[1];  // sparsity metadata
  using BRegisters = uint64_t[1];  // shared-memory matrix descriptor for B
  using CRegisters = uint32_t[24]; // per-thread s32 accumulator fragment

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %31, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.u8.satfinite "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23},"
      "{%24, %25, %26, %27},"
      " %28,"
      " %29, %30,"
      " p;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23)
      :  "r"(a00), "r"(a01), "r"(a02), "r"(a03),
         "l"(desc_b),
         "r"(e), "n"(int32_t(spsel)),
         "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x80x64 TN S32+=U8*U8
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x80x64_S32U8U8_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[40];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %47, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.u8 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, 
%25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & 
d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t 
const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), 
"+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, 
uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8U8_RS_TN +{ + using DRegisters = void; + 
using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, 
%38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & 
d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), 
"+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + 
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + 
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + 
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), 
"+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & 
d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + 
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, 
+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " 
p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, 
      uint32_t & d007,
      uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
      uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
      uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
      uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
      uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
      uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
      uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
      uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
      uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
      uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
      uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
      uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
      uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
      uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
      uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
      uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
      uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
      uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
      uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
      uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
      uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
      uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
      uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
      uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      // p <- (scale_D != 0); %111 = scale_D (104 accs + 4 A regs + desc + e + spsel precede it)
      "setp.ne.b32 p, %111, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.u8 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103},"
      "{%104, %105, %106, %107},"
      " %108,"
      " %109, %110,"
      " p;\n"
    "}\n"
      : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
        "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
        "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
        "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
        "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
        "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
        "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
        "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
        "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
        "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
        "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
        "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
        "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
        "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
        "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
        "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
        "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
        "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
        "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
        "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
        "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
        "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
        "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
        "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
        "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
        "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103)
      : "r"(a000), "r"(a001), "r"(a002), "r"(a003),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x208x64 TN S32+=U8*U8
// Saturating (.satfinite) variant; same register layout as GMMA_64x208x64_S32U8U8_RS_TN.
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x208x64_S32U8U8_RS_TN_SATURATE
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];    // A fragment: 4 x 32-bit registers
  using ERegisters = uint32_t[1];    // sparsity-metadata word
  using BRegisters = uint64_t[1];    // B matrix descriptor
  using CRegisters = uint32_t[104];  // 104 s32 accumulators for the 64x208 tile

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
      uint64_t const& desc_b,
      uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
      uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
      uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
      uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
      uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
      uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
      uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
      uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
      uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
      uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
      uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
      uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
      uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
      uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
      uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
      uint32_t & d060, uint32_t & d061, uint32_t & d062,
      uint32_t & d063,
      uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
      uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
      uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
      uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
      uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
      uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
      uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
      uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
      uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
      uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      // p <- (scale_D != 0); %111 = scale_D (104 accs + 4 A regs + desc + e + spsel precede it)
      "setp.ne.b32 p, %111, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.u8.satfinite "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103},"
      "{%104, %105, %106, %107},"
      " %108,"
      " %109, %110,"
      " p;\n"
    "}\n"
      : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
        "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
        "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
        "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
        "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
        "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
        "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
        "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
        "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
        "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
        "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
        "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
        "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
        "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
        "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
        "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
        "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
        "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
        "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
        "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
        "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
        "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
        "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
        "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
        "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
        "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103)
      : "r"(a000), "r"(a001), "r"(a002), "r"(a003),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x224x64 TN S32+=U8*U8
// A fragment in registers (RS); B via matrix descriptor. (fma continues in next chunk.)
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x224x64_S32U8U8_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];    // A fragment: 4 x 32-bit registers
  using ERegisters = uint32_t[1];    // sparsity-metadata word
  using BRegisters = uint64_t[1];    // B matrix descriptor
  using CRegisters = uint32_t[112];  // 112 s32 accumulators for the 64x224 tile

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
      uint64_t const& desc_b,
      uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
      uint32_t & d004, uint32_t & d005, uint32_t & d006,
      uint32_t & d007,
      uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
      uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
      uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
      uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
      uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
      uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
      uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
      uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
      uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
      uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
      uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
      uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
      uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
      uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
      uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
      uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
      uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
      uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
      uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
      uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
      uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
      uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
      uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
      uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
      uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107,
      uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      // p <- (scale_D != 0); %119 = scale_D (112 accs + 4 A regs + desc + e + spsel precede it)
      "setp.ne.b32 p, %119, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.u8 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103, "
      " %104, %105, %106, %107, %108, %109, %110, %111},"
      "{%112, %113, %114, %115},"
      " %116,"
      " %117, %118,"
      " p;\n"
    "}\n"
      : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
        "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
        "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
        "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
        "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
        "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
        "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
        "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
        "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
        "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
        "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
        "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
        "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
        "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
        "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
        "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
        "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
        "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
        "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
        "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
        "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
        "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
        "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
        "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
        "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
        "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103),
        "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107),
        "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111)
      : "r"(a000), "r"(a001), "r"(a002), "r"(a003),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x224x64 TN S32+=U8*U8
// Saturating (.satfinite) variant; same register layout as GMMA_64x224x64_S32U8U8_RS_TN.
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x224x64_S32U8U8_RS_TN_SATURATE
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];    // A fragment: 4 x 32-bit registers
  using ERegisters = uint32_t[1];    // sparsity-metadata word
  using BRegisters = uint64_t[1];    // B matrix descriptor
  using CRegisters = uint32_t[112];  // 112 s32 accumulators for the 64x224 tile

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
      uint64_t const& desc_b,
      uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
      uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
      uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
      uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
      uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
      uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
      uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
      uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
      uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
      uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
      uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
      uint32_t & d044, uint32_t & d045, uint32_t & d046,
      uint32_t & d047,
      uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
      uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
      uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
      uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
      uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
      uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
      uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
      uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
      uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
      uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
      uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
      uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
      uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
      uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
      uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107,
      uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      // p <- (scale_D != 0); %119 = scale_D (112 accs + 4 A regs + desc + e + spsel precede it)
      "setp.ne.b32 p, %119, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.u8.satfinite "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103, "
      " %104, %105, %106, %107, %108, %109, %110, %111},"
      "{%112, %113, %114, %115},"
      " %116,"
      " %117, %118,"
      " p;\n"
    "}\n"
      : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
        "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
        "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
        "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
        "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
        "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
        "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
        "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
        "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
        "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
        "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
        "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
        "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
        "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
        "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
        "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
        "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
        "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
        "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
        "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
        "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
        "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
        "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
        "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
        "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
        "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103),
        "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107),
        "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111)
      : "r"(a000), "r"(a001), "r"(a002), "r"(a003),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////
// SPARSE GMMA 64x240x64 TN S32+=U8*U8
// A fragment in registers (RS); B via matrix descriptor.
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x240x64_S32U8U8_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];    // A fragment: 4 x 32-bit registers
  using ERegisters = uint32_t[1];    // sparsity-metadata word
  using BRegisters = uint64_t[1];    // B matrix descriptor
  using CRegisters = uint32_t[120];  // 120 s32 accumulators for the 64x240 tile

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
      uint64_t const& desc_b,
      uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
      uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
      uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
      uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
      uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
      uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
      uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
      uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
      uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
      uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
      uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
      uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
      uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
      uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
      uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
      uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
      uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
      uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
      uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
      uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
      uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
      uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
      uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
      uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
      uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
      uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
      uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107,
      uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111,
      uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115,
      uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      // p <- (scale_D != 0); %127 = scale_D (120 accs + 4 A regs + desc + e + spsel precede it)
      "setp.ne.b32 p, %127, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.u8 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103, "
      " %104, %105, %106, %107, %108, %109, %110, %111, "
      " %112, %113, %114, %115, %116, %117, %118, %119},"
      "{%120, %121, %122, %123},"
      " %124,"
      " %125, %126,"
      " p;\n"
    "}\n"
      : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
        "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
        "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
        "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
        "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
        "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
        "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
        "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
        "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
        "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
        "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
        "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
        "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
        "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
        "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
        "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
        "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
        "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
        "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
        "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
        "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
        "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
        "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
        "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
        "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
        "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103),
        "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107),
        "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111),
        "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115),
        "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119)
      : "r"(a000), "r"(a001), "r"(a002), "r"(a003),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x240x64 TN S32+=U8*U8
// Saturating (.satfinite) variant; same register layout as GMMA_64x240x64_S32U8U8_RS_TN.
// (fma continues in next chunk.)
template <
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x240x64_S32U8U8_RS_TN_SATURATE
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];    // A fragment: 4 x 32-bit registers
  using ERegisters = uint32_t[1];    // sparsity-metadata word
  using BRegisters = uint64_t[1];    // B matrix descriptor
  using CRegisters = uint32_t[120];  // 120 s32 accumulators for the 64x240 tile

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
      uint64_t const&
      desc_b,
      uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
      uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
      uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
      uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
      uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
      uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
      uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
      uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
      uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
      uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
      uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
      uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
      uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
      uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
      uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
      uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
      uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
      uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
      uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
      uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
      uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
      uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
      uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
      uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
      uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
      uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
      uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107,
      uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111,
      uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115,
      uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      // p <- (scale_D != 0); %127 = scale_D (120 accs + 4 A regs + desc + e + spsel precede it)
      "setp.ne.b32 p, %127, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.u8.satfinite "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103, "
      " %104, %105, %106, %107, %108, %109, %110, %111, "
      " %112, %113, %114, %115, %116, %117, %118, %119},"
      "{%120, %121, %122, %123},"
      " %124,"
      " %125, %126,"
      " p;\n"
    "}\n"
      : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
        "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
        "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
        "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
        "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
        "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
        "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
        "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
        "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
        "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
        "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
        "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
        "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
        "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
        "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
        "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
        "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
        "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
        "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
        "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
        "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
        "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
        "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
        "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
        "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
        "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103),
        "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107),
        "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111),
        "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115),
        "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119)
      : "r"(a000), "r"(a001), "r"(a002), "r"(a003),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x24x64 TN F16+=E4M3*E4M3
// Both A and B referenced through matrix descriptors (SS); per-operand negation via scaleA/scaleB.
// (asm body continues in next chunk.)
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x24x64_F16E4M3E4M3_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];  // A matrix descriptor
  using ERegisters = uint32_t[1];  // sparsity-metadata word
  using BRegisters = uint64_t[1];  // B matrix descriptor
  using CRegisters = uint32_t[6];  // 6 packed f16x2 accumulators for the 64x24 tile

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
      uint32_t & d4, uint32_t & d5,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      // p <- (scale_D != 0); %10 = scale_D (6 accs + desc_a + desc_b + e + spsel precede it)
      "setp.ne.b32 p, %10, 0;\n"
"wgmma.mma_async.sp.sync.aligned.m64n24k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " %8, %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " %11, %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x24x64 TN F32+=E4M3*E4M3 ("SS": A and B both sourced from SMEM via 64-bit matrix descriptors)
+template <
+  GMMA::ScaleIn  scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn  scaleB = GMMA::ScaleIn::One,
+  GMMA::SparseSel spsel = GMMA::SparseSel::Zero  // metadata selector, passed as an "n" immediate to the instruction
+>
+struct GMMA_64x24x64_F32E4M3E4M3_SS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];   // matrix descriptor for A
+  using ERegisters = uint32_t[1];   // sparsity-metadata word (the "e" operand of wgmma.mma_async.sp)
+  using BRegisters = uint64_t[1];   // matrix descriptor for B
+  using CRegisters = float[12];     // f32 accumulators
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      float         & d00, float & d01, float & d02, float & d03,
+      float         & d04, float & d05, float & d06, float & d07,
+      float         & d08, float & d09, float & d10, float & d11,
+      uint32_t const& e,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)  // One: accumulate into D; Zero: overwrite D
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %16, 0;\n"  // p <- (scale_D != 0); forwarded as the scale-d predicate operand below
+      "wgmma.mma_async.sp.sync.aligned.m64n24k64.f32.e4m3.e4m3 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11},"
+      " %12,"
+      " %13,"
+      " %14, %15,"
+      " p, %17, %18;\n"
+    "}\n"
+      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
+        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
+        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11)
+      :  "l"(desc_a),
+         "l"(desc_b),
+         "r"(e), "n"(int32_t(spsel)),
+         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x24x64 TN F32+=E4M3*E4M3 ("RS": A held in registers, B sourced from SMEM via descriptor)
+template <
+  GMMA::ScaleIn  scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn  scaleB = GMMA::ScaleIn::One,
+  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x24x64_F32E4M3E4M3_RS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint32_t[4];   // A fragment held in registers
+  using ERegisters = uint32_t[1];   // sparsity-metadata word
+  using BRegisters = uint64_t[1];   // matrix descriptor for B
+  using CRegisters = float[12];     // f32 accumulators
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
+      uint64_t const& desc_b,
+      float         & d00, float & d01, float & d02, float & d03,
+      float         & d04, float & d05, float & d06, float & d07,
+      float         & d08, float & d09, float & d10, float & d11,
+      uint32_t const& e,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %19, 0;\n"  // p <- (scale_D != 0)
+      "wgmma.mma_async.sp.sync.aligned.m64n24k64.f32.e4m3.e4m3 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11},"
+      "{%12, %13, %14, %15},"
+      " %16,"
+      " %17, %18,"
+      " p, %20, %21;\n"
+    "}\n"
+      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
+        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
+        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11)
+      :  "r"(a00), "r"(a01), "r"(a02), "r"(a03),
+         "l"(desc_b),
+         "r"(e), "n"(int32_t(spsel)),
+         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x40x64 TN F16+=E4M3*E4M3
+template <
+  GMMA::ScaleIn  scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn  scaleB = GMMA::ScaleIn::One,
+  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x40x64_F16E4M3E4M3_SS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using ERegisters = uint32_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters
= uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " %12, %13," + " p, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const 
scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " %15, %16," + " p, %18, %19;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n40k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f32.e4m3.e4m3 " 
+ "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : 
"+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), 
"+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " %16, %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x56x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x56x64 TN F16+=E4M3*E4M3 ("RS": A held in registers, B sourced from SMEM via descriptor)
+template <
+  GMMA::ScaleIn  scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn  scaleB = GMMA::ScaleIn::One,
+  GMMA::SparseSel spsel = GMMA::SparseSel::Zero  // metadata selector, passed as an "n" immediate to the instruction
+>
+struct GMMA_64x56x64_F16E4M3E4M3_RS_TN
+{
+  using DRegisters = void;
+  using ARegisters = uint32_t[4];   // A fragment held in registers
+  using ERegisters = uint32_t[1];   // sparsity-metadata word (the "e" operand of wgmma.mma_async.sp)
+  using BRegisters = uint64_t[1];   // matrix descriptor for B
+  using CRegisters = uint32_t[14];  // f16 accumulators packed two-per-u32
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
+      uint64_t const& desc_b,
+      uint32_t      & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+      uint32_t      & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+      uint32_t      & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
+      uint32_t      & d12, uint32_t & d13,
+      uint32_t const& e,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)  // One: accumulate into D; Zero: overwrite D
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %21, 0;\n"  // p <- (scale_D != 0); forwarded as the scale-d predicate operand below
+      "wgmma.mma_async.sp.sync.aligned.m64n56k64.f16.e4m3.e4m3 "
+      "{%0, %1, %2, %3, %4, %5, %6, %7, "
+      " %8, %9, %10, %11, %12, %13},"
+      "{%14, %15, %16, %17},"
+      " %18,"
+      " %19, %20,"
+      " p, %22, %23;\n"
+    "}\n"
+      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
+        "+r"(d12), "+r"(d13)
+      :  "r"(a00), "r"(a01), "r"(a02), "r"(a03),
+         "l"(desc_b),
+         "r"(e), "n"(int32_t(spsel)),
+         "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), 
"+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " %20, %21," + " p, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), 
+ "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " %23, %24," + " p, %26, %27;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + 
"r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), 
"+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + 
".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & 
d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, 
+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & 
d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + 
using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " %24, %25," + " p, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(e), 
"n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " %27, %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), 
"+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, 
%22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & 
d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " %28, %29," + " p, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 
SPARSE GMMA 64x104x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " %31, %32," + " p, %34, %35;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, 
" + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, 
float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } 
+}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), 
+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = 
uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), 
"+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" 
+ "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " %32, %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, 
uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " %35, %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, 
+ float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + 
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> 
+struct GMMA_64x136x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " %36, %37," + " p, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " %39, %40," + " p, %42, %43;\n" + "}\n" + : "+r"(d00), "+r"(d01), 
"+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & 
d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %72, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " %70, %71," + " p, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x136x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %75, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n136k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " %73, %74," + " p, %76, %77;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = 
void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + 
"}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & 
d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), 
"n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 
64x152x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " %40, %41," + " p, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), 
"+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n152k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " %43, %44," + " p, %46, %47;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & 
d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %80, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " %78, %79," + " p, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), 
"+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, 
float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %83, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " %81, %82," + " p, %84, %85;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), 
"+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t 
& d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn 
scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, 
" + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; 
+ using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, 
%68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + 
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " %44, %45," + " p, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x168x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, 
" + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " %47, %48," + " p, %50, %51;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & 
d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %88, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " %86, %87," + " p, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), 
"+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & 
d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %91, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " %89, %90," + " p, %92, %93;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), 
+ "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + 
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters 
= uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + 
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & 
d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), 
"+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, 
float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), 
"+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm 
volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + " %46," + " %47," + " %48, %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, 
+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " %51, %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float 
& d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %96, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " %94, %95," + " p, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), 
"+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, 
+ float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %99, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " %97, %98," + " p, %100, %101;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), 
"+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const 
scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " %52, %53," + " p, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = 
uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " %55, %56," + " p, %58, %59;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + 
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, 
float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %104, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " %102, %103," + " p, %105, %106;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + 
"+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float 
& d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %107, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, 
%75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " %105, %106," + " p, %108, %109;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), 
/* (cont.) remaining output constraints of the 64x208 F16 SS struct, then the RS variant: A comes from 4 uint32_t registers instead of an smem descriptor; same 52-register C fragment. */
"+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, 
/* (cont.) body of GMMA_64x208x64_F16E4M3E4M3_RS_TN: asm operands %52-%55 are the A registers, %56 desc_b, %57 metadata e, %58 spsel; predicate p from %59 (scale_D); %60/%61 scaleA/scaleB. */
uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
/* Sparse GMMA 64x208x64 TN, F32 += E4M3*E4M3, SS variant: 104 float accumulators (d000-d103), A and B via smem descriptors. */
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & 
/* (cont.) 64x208 F32 SS asm: %0-%103 accumulators, %104 desc_a, %105 desc_b, %106 metadata e, %107 spsel; predicate p from %108 (scale_D); %109/%110 scaleA/scaleB. */
d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), 
/* (cont.) end of 64x208 F32 SS output constraints, then the RS variant: A in 4 uint32_t registers, same 104-float C fragment. */
+ "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + 
/* (cont.) 64x208 F32 RS asm: %0-%103 accumulators, %104-%107 A registers, %108 desc_b, %109 metadata e, %110 spsel; predicate p from %111 (scale_D); %112/%113 scaleA/scaleB. */
float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), 
/* (cont.) end of 64x208 F32 RS, then sparse GMMA 64x216x64 TN, F16 += E4M3*E4M3, SS variant: C fragment grows to 54 x uint32_t. */
"+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = 
/* (cont.) 64x216 F16 SS body: %0-%53 accumulators, %54 desc_a, %55 desc_b, %56 metadata e, %57 spsel; predicate p from %58 (scale_D); %59/%60 scaleA/scaleB. */
uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " %56, %57," + " p, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), 
/* (cont.) end of 64x216 F16 SS, then the RS variant: A in 4 uint32_t registers, same 54-register C fragment. */
"+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & 
/* (cont.) 64x216 F16 RS asm (%54-%57 A registers, %58 desc_b, %59 e, %60 spsel; p from %61; %62/%63 scaleA/scaleB), then header of the 64x216 F32 SS variant. */
d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " %59, %60," + " p, %62, %63;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn 
/* Sparse GMMA 64x216x64 TN, F32 += E4M3*E4M3, SS variant: 108 float accumulators (d000-d107), A and B via smem descriptors. */
scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t 
/* (cont.) 64x216 F32 SS asm: %0-%107 accumulators, %108 desc_a, %109 desc_b, %110 metadata e, %111 spsel; predicate p from %112 (scale_D); %113/%114 scaleA/scaleB. */
const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %112, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " %110, %111," + " p, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), 
/* (cont.) end of 64x216 F32 SS, then the RS variant: A in 4 uint32_t registers, same 108-float C fragment. */
"+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, 
/* (cont.) 64x216 F32 RS asm: %0-%107 accumulators, %108-%111 A registers, %112 desc_b, %113 metadata e, %114 spsel; predicate p from %115 (scale_D); %116/%117 scaleA/scaleB. */
float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %115, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " %113, %114," + " p, %116, %117;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), 
/* (cont.) end of 64x216 F32 RS, then sparse GMMA 64x224x64 TN, F16 += E4M3*E4M3, SS variant: C fragment is 56 x uint32_t. */
"+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = 
/* (cont.) 64x224 F16 SS body: %0-%55 accumulators, %56 desc_a, %57 desc_b, %58 metadata e, %59 spsel; predicate p from %60 (scale_D); %61/%62 scaleA/scaleB. */
uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), 
/* (cont.) end of 64x224 F16 SS, then the RS variant: A in 4 uint32_t registers, same 56-register C fragment. */
"+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, 
uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E4M3_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, 
float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), 
"+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, 
float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, 
%66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f16.e4m3.e4m3 " + 
"{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " %60, %61," + " p, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t 
const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " %63, %64," + " p, %66, %67;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), 
"+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + 
float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %120, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " 
%116," + " %117," + " %118, %119," + " p, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, 
float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %123, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " %121, %122," + " p, %124, %125;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + 
"+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & 
d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + 
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t 
const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + 
GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, 
float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + 
"+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & 
d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, 
%50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), 
"+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut 
const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " %64, %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " 
%48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " %67, %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, 
float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %128, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, 
%15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " %126, %127," + " p, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), 
"+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + 
float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %131, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, 
%110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " %129, %130," + " p, %132, %133;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x248x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " %8, %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters 
= uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " %11, %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," 
+ "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " %12, %13," + " p, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); 
+#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " %15, %16," + " p, %18, %19;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 
64x40x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + 
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & 
d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & 
d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, 
uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " %16, %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + 
uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " %19, %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) 
+ { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, 
float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, 
uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " %20, %21," + " p, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + 
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " %23, %24," + " p, %26, %27;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & 
d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters 
= uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F32E4M3E5M2_RS_TN 
without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif 
+ } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E5M2_RS_TN 
without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), 
"+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, 
%47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, 
uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " %24, %25," + " p, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, 
uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " %27, %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + 
float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), 
"+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + 
"setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " %28, %29," + " p, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& 
e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " %31, %32," + " p, %34, %35;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & 
d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F32E4M3E5M2_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " 
%48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & 
d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, 
uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using 
ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), 
"+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + 
uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x120x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " %32, %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " %35, %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + 
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm 
volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using 
CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + 
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " %36, %37," + " p, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, 
uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " %39, %40," + " p, %42, %43;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %72, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " %70, %71," + " p, %73, 
%74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, 
float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %75, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " %73, %74," + " p, %76, %77;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), 
"+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + 
"setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & 
d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
+ uint32_t const& e,
+ GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+ {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+ cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
+ asm volatile(
+ "{\n"
+ ".reg .pred p;\n"
+ "setp.ne.b32 p, %43, 0;\n"
+ "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e4m3.e5m2 "
+ "{%0, %1, %2, %3, %4, %5, %6, %7, "
+ " %8, %9, %10, %11, %12, %13, %14, %15, "
+ " %16, %17, %18, %19, %20, %21, %22, %23, "
+ " %24, %25, %26, %27, %28, %29, %30, %31, "
+ " %32, %33, %34, %35},"
+ "{%36, %37, %38, %39},"
+ " %40,"
+ " %41, %42,"
+ " p, %44, %45;\n"
+ "}\n"
+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35)
+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
+ "l"(desc_b),
+ "r"(e), "n"(int32_t(spsel)),
+ "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+ CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x144x64 TN F32+=E4M3*E5M2 (SS: A and B both passed as shared-memory descriptors)
+template <
+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
+ GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x144x64_F32E4M3E5M2_SS_TN
+{
+ using DRegisters = void;
+ using ARegisters = uint64_t[1];
+ using ERegisters = uint32_t[1];
+ using BRegisters = uint64_t[1];
+ using CRegisters = float[72];
+
+ CUTE_HOST_DEVICE static void
+ fma(uint64_t const& desc_a,
+ uint64_t const& desc_b,
+ float & d00, float & d01, float & d02, float & d03,
+ float & d04, float & d05, float & d06, float & d07,
+ float & d08, float & d09, float & d10, float & d11,
+ float & d12, float & d13, float & d14, float & d15,
+ float & d16, float & d17, float & d18, float & d19,
+ float & d20, float & d21, float & d22, float & d23,
+ float & d24, float & d25, float & d26, float & d27,
+ float & d28, float & d29, float & d30, float & d31,
+ float & d32, float & d33, float & d34, float & d35,
+ float & d36, float & d37, float & d38, float & d39,
+ float & d40, float & d41, float & d42, float & d43,
+ float & d44, float & d45, float & d46, float & d47,
+ float & d48, float & d49, float & d50, float & d51,
+ float & d52, float & d53, float & d54, float & d55,
+ float & d56, float & d57, float & d58, float & d59,
+ float & d60, float & d61, float & d62, float & d63,
+ float & d64, float & d65, float & d66, float & d67,
+ float & d68, float & d69, float & d70, float & d71,
+ uint32_t const& e,
+ GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+ {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+ cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+ asm volatile(
+ "{\n"
+ ".reg .pred p;\n"
+ "setp.ne.b32 p, %76, 0;\n"
+ "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e4m3.e5m2 "
+ "{%0, %1, %2, %3, %4, %5, %6, %7, "
+ " %8, %9, %10, %11, %12, %13, %14, %15, "
+ " %16, %17, %18, %19, %20, %21, %22, %23, "
+ " %24, %25, %26, %27, %28, %29, %30, %31, "
+ " %32, %33, %34, %35, %36, %37, %38, %39, "
+ " %40, %41, %42, %43, %44, %45, %46, %47, "
+ " %48, %49, %50, %51, %52, %53, %54, %55, "
+ " %56, %57, %58, %59, %60, %61, %62, %63, "
+ " %64, %65, %66, %67, %68, %69, %70, %71},"
+ " %72,"
+ " %73,"
+ " %74, %75,"
+ " p, %77, %78;\n"
+ "}\n"
+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
+ "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
+ "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
+ "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
+ "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
+ "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
+ "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71)
+ : "l"(desc_a),
+ "l"(desc_b),
+ "r"(e), "n"(int32_t(spsel)),
+ "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+ CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x144x64 TN F32+=E4M3*E5M2 (RS: A held in registers, B passed as shared-memory descriptor)
+template <
+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
+ GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x144x64_F32E4M3E5M2_RS_TN
+{
+ using DRegisters = void;
+ using ARegisters = uint32_t[4];
+ using ERegisters = uint32_t[1];
+ using BRegisters = uint64_t[1];
+ using CRegisters = float[72];
+
+ CUTE_HOST_DEVICE static void
+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
+ uint64_t const& desc_b,
+ float & d00, float & d01, float & d02, float & d03,
+ float & d04, float & d05, float & d06, float & d07,
+ float & d08, float & d09, float & d10, float & d11,
+ float & d12, float & d13, float & d14, float & d15,
+ float & d16, float & d17, float & d18, float & d19,
+ float & d20, float & d21, float & d22, float & d23,
+ float & d24, float & d25, float & d26, float & d27,
+ float & d28, float & d29, float & d30, float & d31,
+ float & d32, float & d33, float & d34, float & d35,
+ float & d36, float & d37, float & d38, float & d39,
+ float & d40, float & d41, float & d42, float & d43,
+ float & d44, float & d45, float & d46, float & d47,
+ float & d48, float & d49, float & d50, float & d51,
+ float & d52, float & d53, float & d54, float & d55,
+ float & d56, float & d57, float & d58, float & d59,
+ float & d60, float & d61, float & d62, float & d63,
+ float & d64, float & d65, float & d66, float & d67,
+ float & d68, float & d69, float & d70, float & d71,
+ uint32_t const& e,
+ GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+ {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+ cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
+ asm volatile(
+ "{\n"
+ ".reg .pred p;\n"
+ "setp.ne.b32 p, %79, 0;\n"
+ "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e4m3.e5m2 "
+ "{%0, %1, %2, %3, %4, %5, %6, %7, "
+ " %8, %9, %10, %11, %12, %13, %14, %15, "
+ " %16, %17, %18, %19, %20, %21, %22, %23, "
+ " %24, %25, %26, %27, %28, %29, %30, %31, "
+ " %32, %33, %34, %35, %36, %37, %38, %39, "
+ " %40, %41, %42, %43, %44, %45, %46, %47, "
+ " %48, %49, %50, %51, %52, %53, %54, %55, "
+ " %56, %57, %58, %59, %60, %61, %62, %63, "
+ " %64, %65, %66, %67, %68, %69, %70, %71},"
+ "{%72, %73, %74, %75},"
+ " %76,"
+ " %77, %78,"
+ " p, %80, %81;\n"
+ "}\n"
+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
+ "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
+ "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
+ "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
+ "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
+ "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
+ "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71)
+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
+ "l"(desc_b),
+ "r"(e), "n"(int32_t(spsel)),
+ "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+ CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x152x64 TN F16+=E4M3*E5M2 (SS: A and B both passed as shared-memory descriptors)
+template <
+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
+ GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x152x64_F16E4M3E5M2_SS_TN
+{
+ using DRegisters = void;
+ using ARegisters = uint64_t[1];
+ using ERegisters = uint32_t[1];
+ using BRegisters = uint64_t[1];
+ using CRegisters = uint32_t[38];
+
+ CUTE_HOST_DEVICE static void
+ fma(uint64_t const& desc_a,
+ uint64_t const& desc_b,
+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
+ uint32_t & d36, uint32_t & d37,
+ uint32_t const& e,
+ GMMA::ScaleOut 
const scale_D = GMMA::ScaleOut::One)
+ {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+ cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+ asm volatile(
+ "{\n"
+ ".reg .pred p;\n"
+ "setp.ne.b32 p, %42, 0;\n"
+ "wgmma.mma_async.sp.sync.aligned.m64n152k64.f16.e4m3.e5m2 "
+ "{%0, %1, %2, %3, %4, %5, %6, %7, "
+ " %8, %9, %10, %11, %12, %13, %14, %15, "
+ " %16, %17, %18, %19, %20, %21, %22, %23, "
+ " %24, %25, %26, %27, %28, %29, %30, %31, "
+ " %32, %33, %34, %35, %36, %37},"
+ " %38,"
+ " %39,"
+ " %40, %41,"
+ " p, %43, %44;\n"
+ "}\n"
+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
+ "+r"(d36), "+r"(d37)
+ : "l"(desc_a),
+ "l"(desc_b),
+ "r"(e), "n"(int32_t(spsel)),
+ "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+ CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x152x64 TN F16+=E4M3*E5M2 (RS: A held in registers, B passed as shared-memory descriptor)
+template <
+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
+ GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x152x64_F16E4M3E5M2_RS_TN
+{
+ using DRegisters = void;
+ using ARegisters = uint32_t[4];
+ using ERegisters = uint32_t[1];
+ using BRegisters = uint64_t[1];
+ using CRegisters = uint32_t[38];
+
+ CUTE_HOST_DEVICE static void
+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
+ uint64_t const& desc_b,
+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
+ uint32_t & d36, uint32_t & d37,
+ uint32_t const& e,
+ GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+ {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+ cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
+ asm volatile(
+ "{\n"
+ ".reg .pred p;\n"
+ "setp.ne.b32 p, %45, 0;\n"
+ "wgmma.mma_async.sp.sync.aligned.m64n152k64.f16.e4m3.e5m2 "
+ "{%0, %1, %2, %3, %4, %5, %6, %7, "
+ " %8, %9, %10, %11, %12, %13, %14, %15, "
+ " %16, %17, %18, %19, %20, %21, %22, %23, "
+ " %24, %25, %26, %27, %28, %29, %30, %31, "
+ " %32, %33, %34, %35, %36, %37},"
+ "{%38, %39, %40, %41},"
+ " %42,"
+ " %43, %44,"
+ " p, %46, %47;\n"
+ "}\n"
+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
+ "+r"(d36), "+r"(d37)
+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
+ "l"(desc_b),
+ "r"(e), "n"(int32_t(spsel)),
+ "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+ CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x152x64 TN F32+=E4M3*E5M2 (SS: A and B both passed as shared-memory descriptors)
+template <
+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
+ GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x152x64_F32E4M3E5M2_SS_TN
+{
+ using DRegisters = void;
+ using ARegisters = uint64_t[1];
+ using ERegisters = uint32_t[1];
+ using BRegisters = uint64_t[1];
+ using CRegisters = float[76];
+
+ CUTE_HOST_DEVICE static void
+ fma(uint64_t const& desc_a,
+ uint64_t const& desc_b,
+ float & d00, float & d01, float & d02, float & d03,
+ float & d04, float & d05, float & d06, float & d07,
+ float & d08, float & d09, float & d10, float & d11,
+ float & d12, float & d13, float & d14, float & d15,
+ float & d16, float & d17, float & d18, float & d19,
+ float & d20, float & d21, float & d22, float & d23,
+ float & d24, float & d25, float & d26, float & d27,
+ float & d28, float & d29, float & d30, float & d31,
+ float & d32, float & d33, float & d34, float & d35,
+ float & d36, float & d37, float & d38, float & d39,
+ float & d40, float & d41, float & d42, float & d43,
+ float & d44, float & d45, float & d46, float & d47,
+ float & d48, float & d49, float & d50, float & d51,
+ float & d52, float & d53, float & d54, float & d55,
+ float & d56, float & d57, float & d58, float & d59,
+ float & d60, float & d61, float & d62, float & d63,
+ float & d64, float & d65, float & d66, float & d67,
+ float & d68, float & d69, float & d70, float & d71,
+ float & d72, float & d73, float & d74, float & d75,
+ uint32_t const& e,
+ GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+ {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+ cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+ asm volatile(
+ "{\n"
+ ".reg .pred p;\n"
+ "setp.ne.b32 p, %80, 0;\n"
+ "wgmma.mma_async.sp.sync.aligned.m64n152k64.f32.e4m3.e5m2 "
+ "{%0, %1, %2, %3, %4, %5, %6, %7, "
+ " %8, %9, %10, %11, %12, %13, %14, %15, "
+ " %16, %17, %18, %19, %20, %21, %22, %23, "
+ " %24, %25, %26, %27, %28, %29, %30, %31, "
+ " %32, %33, %34, %35, %36, %37, %38, %39, "
+ " %40, %41, %42, %43, %44, %45, %46, %47, "
+ " %48, %49, %50, %51, %52, %53, %54, %55, "
+ " %56, %57, %58, %59, %60, %61, %62, %63, "
+ " %64, %65, %66, %67, %68, %69, %70, %71, "
+ " %72, %73, %74, %75},"
+ " %76,"
+ " %77,"
+ " %78, %79,"
+ " p, %81, %82;\n"
+ "}\n"
+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
+ "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
+ "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
+ "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
+ "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
+ "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
+ "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
+ "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75)
+ : "l"(desc_a),
+ "l"(desc_b),
+ "r"(e), "n"(int32_t(spsel)),
+ "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+ CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x152x64 TN F32+=E4M3*E5M2 (RS: A held in registers, B passed as shared-memory descriptor)
+template <
+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
+ GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x152x64_F32E4M3E5M2_RS_TN
+{
+ using DRegisters = void;
+ using ARegisters = uint32_t[4];
+ using ERegisters = uint32_t[1];
+ using BRegisters = uint64_t[1];
+ using CRegisters = float[76];
+
+ CUTE_HOST_DEVICE static void
+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
+ uint64_t const& desc_b,
+ float & d00, float & d01, float & d02, float & d03,
+ float & d04, float & d05, float & d06, float & d07,
+ float & d08, float & d09, float & d10, float & d11,
+ float & d12, float & d13, float & d14, float & d15,
+ float & d16, float & d17, float & d18, float & d19,
+ float & d20, float & d21, float & d22, float & d23,
+ float & d24, float & d25, float & d26, float & d27,
+ float & d28, float & d29, float & d30, float & d31,
+ float & d32, float & d33, float & d34, float & d35,
+ float & d36, float & d37, float & d38, float & d39,
+ float & d40, float & d41, float & d42, float & d43,
+ float & d44, float & d45, float & d46, float & d47,
+ float & d48, float & d49, float & d50, float & d51,
+ float & d52, float & d53, float & d54, float & d55,
+ float & d56, float & d57, float & d58, float & d59,
+ float & d60, float & d61, float & d62, float & d63,
+ float & d64, float & d65, float & d66, float & d67,
+ float & d68, float & d69, float & d70, float & d71,
+ float & d72, float & d73, float & d74, float & d75,
+ uint32_t const& e,
+ GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+ {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+ cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
+ asm volatile(
+ "{\n"
+ ".reg .pred p;\n"
+ "setp.ne.b32 p, %83, 0;\n"
+ "wgmma.mma_async.sp.sync.aligned.m64n152k64.f32.e4m3.e5m2 "
+ "{%0, %1, %2, %3, %4, %5, %6, %7, "
+ " %8, %9, %10, %11, %12, %13, %14, %15, "
+ " %16, %17, %18, %19, %20, %21, %22, %23, "
+ " %24, %25, %26, %27, %28, %29, %30, %31, "
+ " %32, %33, %34, %35, %36, %37, %38, %39, "
+ " %40, %41, %42, %43, %44, %45, %46, %47, "
+ " %48, %49, %50, %51, %52, %53, %54, %55, "
+ " %56, %57, %58, %59, %60, %61, %62, %63, "
+ " %64, %65, %66, 
%67, %68, %69, %70, %71, "
+ " %72, %73, %74, %75},"
+ "{%76, %77, %78, %79},"
+ " %80,"
+ " %81, %82,"
+ " p, %84, %85;\n"
+ "}\n"
+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
+ "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
+ "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
+ "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
+ "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
+ "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
+ "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
+ "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75)
+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
+ "l"(desc_b),
+ "r"(e), "n"(int32_t(spsel)),
+ "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+ CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x160x64 TN F16+=E4M3*E5M2 (SS: A and B both passed as shared-memory descriptors)
+template <
+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
+ GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x160x64_F16E4M3E5M2_SS_TN
+{
+ using DRegisters = void;
+ using ARegisters = uint64_t[1];
+ using ERegisters = uint32_t[1];
+ using BRegisters = uint64_t[1];
+ using CRegisters = uint32_t[40];
+
+ CUTE_HOST_DEVICE static void
+ fma(uint64_t const& desc_a,
+ uint64_t const& desc_b,
+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
+ uint32_t const& e,
+ GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+ {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+ cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+ asm volatile(
+ "{\n"
+ ".reg .pred p;\n"
+ "setp.ne.b32 p, %44, 0;\n"
+ "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e4m3.e5m2 "
+ "{%0, %1, %2, %3, %4, %5, %6, %7, "
+ " %8, %9, %10, %11, %12, %13, %14, %15, "
+ " %16, %17, %18, %19, %20, %21, %22, %23, "
+ " %24, %25, %26, %27, %28, %29, %30, %31, "
+ " %32, %33, %34, %35, %36, %37, %38, %39},"
+ " %40,"
+ " %41,"
+ " %42, %43,"
+ " p, %45, %46;\n"
+ "}\n"
+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39)
+ : "l"(desc_a),
+ "l"(desc_b),
+ "r"(e), "n"(int32_t(spsel)),
+ "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+ CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x160x64 TN F16+=E4M3*E5M2 (RS: A held in registers, B passed as shared-memory descriptor)
+template <
+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
+ GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x160x64_F16E4M3E5M2_RS_TN
+{
+ using DRegisters = void;
+ using ARegisters = uint32_t[4];
+ using ERegisters = uint32_t[1];
+ using BRegisters = uint64_t[1];
+ using CRegisters = uint32_t[40];
+
+ CUTE_HOST_DEVICE static void
+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
+ uint64_t const& desc_b,
+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
+ uint32_t const& e,
+ GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+ {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+ cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
+ asm volatile(
+ "{\n"
+ ".reg .pred p;\n"
+ "setp.ne.b32 p, %47, 0;\n"
+ "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e4m3.e5m2 "
+ "{%0, %1, %2, %3, %4, %5, %6, %7, "
+ " %8, %9, %10, %11, %12, %13, %14, %15, "
+ " %16, %17, %18, %19, %20, %21, %22, %23, "
+ " %24, %25, %26, %27, %28, %29, %30, %31, "
+ " %32, %33, %34, %35, %36, %37, %38, %39},"
+ "{%40, %41, %42, %43},"
+ " %44,"
+ " %45, %46,"
+ " p, %48, %49;\n"
+ "}\n"
+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39)
+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
+ "l"(desc_b),
+ "r"(e), "n"(int32_t(spsel)),
+ "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+ CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x160x64 TN F32+=E4M3*E5M2 (SS: A and B both passed as shared-memory descriptors)
+template <
+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
+ GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x160x64_F32E4M3E5M2_SS_TN
+{
+ using DRegisters = void;
+ using ARegisters = uint64_t[1];
+ using ERegisters = uint32_t[1];
+ using BRegisters = uint64_t[1];
+ using CRegisters = float[80];
+
+ CUTE_HOST_DEVICE static void
+ fma(uint64_t const& desc_a,
+ uint64_t const& desc_b,
+ float & d00, float & d01, float & d02, float & d03,
+ float & d04, float & d05, float & d06, float & d07,
+ float & d08, float & d09, float & d10, float & d11,
+ float & d12, float & d13, float & d14, float & d15,
+ float & d16, float & d17, float & d18, float & d19,
+ float & d20, float & d21, float & d22, float & d23,
+ float & d24, float & d25, float & d26, float & d27,
+ float & d28, float & d29, float & d30, float & d31,
+ float & d32, float & d33, float & d34, float & d35,
+ float & d36, float & d37, float & d38, float & d39,
+ float & d40, float & d41, float & d42, float & d43,
+ float & d44, float & d45, float & d46, float & d47,
+ float & d48, float & d49, float & d50, float & d51,
+ float & d52, float & d53, float & d54, float & d55,
+ float & d56, float & d57, float & d58, float & d59,
+ float & d60, float & d61, float & d62, float & d63,
+ float & d64, float & d65, float & d66, float & d67,
+ float & d68, float & d69, float & d70, float & d71,
+ float & d72, float & d73, float & d74, float & d75,
+ float & d76, float & d77, float & d78, float & d79,
+ uint32_t const& e,
+ GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+ {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+ cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+ asm volatile(
+ "{\n"
+ ".reg .pred p;\n"
+ "setp.ne.b32 p, %84, 0;\n"
+ "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e4m3.e5m2 "
+ "{%0, %1, %2, %3, %4, %5, %6, %7, "
+ " %8, %9, %10, %11, %12, %13, %14, %15, "
+ " %16, %17, %18, %19, %20, %21, %22, %23, "
+ " %24, %25, %26, %27, %28, %29, %30, %31, "
+ " %32, %33, %34, %35, %36, %37, %38, %39, "
+ " %40, %41, %42, %43, %44, %45, %46, %47, "
+ " %48, %49, %50, %51, %52, %53, %54, %55, "
+ " %56, %57, %58, %59, %60, %61, %62, %63, "
+ " %64, %65, %66, %67, %68, %69, %70, %71, "
+ " %72, %73, %74, %75, %76, %77, %78, %79},"
+ " %80,"
+ " %81,"
+ " %82, %83,"
+ " p, %85, %86;\n"
+ "}\n"
+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
+ "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
+ "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
+ "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
+ "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
+ "+f"(d64), 
"+f"(d65), "+f"(d66), "+f"(d67),
+ "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
+ "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75),
+ "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79)
+ : "l"(desc_a),
+ "l"(desc_b),
+ "r"(e), "n"(int32_t(spsel)),
+ "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+ CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x160x64 TN F32+=E4M3*E5M2 (RS: A held in registers, B passed as shared-memory descriptor)
+template <
+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
+ GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x160x64_F32E4M3E5M2_RS_TN
+{
+ using DRegisters = void;
+ using ARegisters = uint32_t[4];
+ using ERegisters = uint32_t[1];
+ using BRegisters = uint64_t[1];
+ using CRegisters = float[80];
+
+ CUTE_HOST_DEVICE static void
+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
+ uint64_t const& desc_b,
+ float & d00, float & d01, float & d02, float & d03,
+ float & d04, float & d05, float & d06, float & d07,
+ float & d08, float & d09, float & d10, float & d11,
+ float & d12, float & d13, float & d14, float & d15,
+ float & d16, float & d17, float & d18, float & d19,
+ float & d20, float & d21, float & d22, float & d23,
+ float & d24, float & d25, float & d26, float & d27,
+ float & d28, float & d29, float & d30, float & d31,
+ float & d32, float & d33, float & d34, float & d35,
+ float & d36, float & d37, float & d38, float & d39,
+ float & d40, float & d41, float & d42, float & d43,
+ float & d44, float & d45, float & d46, float & d47,
+ float & d48, float & d49, float & d50, float & d51,
+ float & d52, float & d53, float & d54, float & d55,
+ float & d56, float & d57, float & d58, float & d59,
+ float & d60, float & d61, float & d62, float & d63,
+ float & d64, float & d65, float & d66, float & d67,
+ float & d68, float & d69, float & d70, float & d71,
+ float & d72, float & d73, float & d74, float & d75,
+ float & d76, float & d77, float & d78, float & d79,
+ uint32_t const& e,
+ GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+ {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+ cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
+ asm volatile(
+ "{\n"
+ ".reg .pred p;\n"
+ "setp.ne.b32 p, %87, 0;\n"
+ "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e4m3.e5m2 "
+ "{%0, %1, %2, %3, %4, %5, %6, %7, "
+ " %8, %9, %10, %11, %12, %13, %14, %15, "
+ " %16, %17, %18, %19, %20, %21, %22, %23, "
+ " %24, %25, %26, %27, %28, %29, %30, %31, "
+ " %32, %33, %34, %35, %36, %37, %38, %39, "
+ " %40, %41, %42, %43, %44, %45, %46, %47, "
+ " %48, %49, %50, %51, %52, %53, %54, %55, "
+ " %56, %57, %58, %59, %60, %61, %62, %63, "
+ " %64, %65, %66, %67, %68, %69, %70, %71, "
+ " %72, %73, %74, %75, %76, %77, %78, %79},"
+ "{%80, %81, %82, %83},"
+ " %84,"
+ " %85, %86,"
+ " p, %88, %89;\n"
+ "}\n"
+ : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
+ "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
+ "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
+ "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
+ "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
+ "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
+ "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
+ "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
+ "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
+ "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
+ "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
+ "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
+ "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
+ "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
+ "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75),
+ "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79)
+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
+ "l"(desc_b),
+ "r"(e), "n"(int32_t(spsel)),
+ "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+ CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x168x64 TN F16+=E4M3*E5M2 (SS: A and B both passed as shared-memory descriptors)
+template <
+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
+ GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x168x64_F16E4M3E5M2_SS_TN
+{
+ using DRegisters = void;
+ using ARegisters = uint64_t[1];
+ using ERegisters = uint32_t[1];
+ using BRegisters = uint64_t[1];
+ using CRegisters = uint32_t[42];
+
+ CUTE_HOST_DEVICE static void
+ fma(uint64_t const& desc_a,
+ uint64_t const& desc_b,
+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
+ uint32_t & d40, uint32_t & d41,
+ uint32_t const& e,
+ GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+ {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+ cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+ asm volatile(
+ "{\n"
+ ".reg .pred p;\n"
+ "setp.ne.b32 p, %46, 0;\n"
+ "wgmma.mma_async.sp.sync.aligned.m64n168k64.f16.e4m3.e5m2 "
+ "{%0, %1, %2, %3, %4, %5, %6, %7, "
+ " %8, %9, %10, %11, %12, %13, %14, %15, "
+ " %16, %17, %18, %19, %20, %21, %22, %23, "
+ " %24, %25, %26, %27, %28, %29, %30, %31, "
+ " %32, %33, %34, %35, %36, %37, %38, %39, "
+ " %40, %41},"
+ " %42,"
+ " %43,"
+ " %44, %45,"
+ " p, %47, %48;\n"
+ "}\n"
+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
+ "+r"(d40), "+r"(d41)
+ : "l"(desc_a),
+ "l"(desc_b),
+ "r"(e), "n"(int32_t(spsel)),
+ "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+ CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x168x64 TN F16+=E4M3*E5M2 (RS: A held in registers, B passed as shared-memory descriptor)
+template <
+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
+ GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x168x64_F16E4M3E5M2_RS_TN
+{
+ using DRegisters = void;
+ using ARegisters = uint32_t[4];
+ using ERegisters = uint32_t[1];
+ using BRegisters = uint64_t[1];
+ using CRegisters = uint32_t[42];
+
+ CUTE_HOST_DEVICE static void
+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
+ uint64_t const& desc_b,
+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
+ uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
+ uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
+ uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
+ uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
+ uint32_t & d40, uint32_t & d41,
+ uint32_t const& e,
+ GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+ {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+ cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
+ asm volatile(
+ "{\n"
+ ".reg .pred p;\n"
+ "setp.ne.b32 p, %49, 0;\n"
+ "wgmma.mma_async.sp.sync.aligned.m64n168k64.f16.e4m3.e5m2 "
+ "{%0, %1, %2, %3, %4, %5, %6, %7, "
+ " %8, %9, %10, %11, %12, %13, %14, %15, "
+ " %16, %17, %18, %19, %20, %21, %22, %23, "
+ " %24, %25, %26, %27, %28, %29, %30, %31, "
+ " %32, %33, %34, %35, %36, %37, %38, %39, "
+ " %40, %41},"
+ "{%42, %43, %44, %45},"
+ " %46,"
+ " %47, %48,"
+ " p, %50, %51;\n"
+ "}\n"
+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
+ "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
+ "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
+ "+r"(d40), "+r"(d41)
+ : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
+ "l"(desc_b),
+ "r"(e), "n"(int32_t(spsel)),
+ "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
+#else
+ CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x168x64 TN F32+=E4M3*E5M2 (SS: A and B both passed as shared-memory descriptors)
+template <
+ GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %88, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, 
" + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " %86, %87," + " p, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x168x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %91, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, 
%36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " %89, %90," + " p, %92, %93;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E4M3E5M2_SS_TN +{ + 
using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), 
"+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm 
volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & 
d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), 
"+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, 
float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, 
%97;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & 
d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + " %46," + " %47," + " %48, %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, 
%10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " %51, %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & 
d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %96, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " %94, %95," + " p, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), 
"+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, 
float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %99, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, 
%93, %94, %95}," + " %96," + " %97, %98," + " p, %100, %101;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + 
fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " %52, %53," + " p, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), 
"+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { 
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " %55, %56," + " p, %58, %59;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = 
float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %104, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, 
%22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " %102, %103," + " p, %105, %106;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x200x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & 
d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %107, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " %105, %106," + " p, %108, %109;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), 
"+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, 
uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F16+=E4M3*E5M2 +template < + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " 
%56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & 
d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," 
+ " %104," + " %105," + " %106, %107," + " p, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x208x64_F32E4M3E5M2_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];   // A fragment held in registers (RS = reg * smem)
  using ERegisters = uint32_t[1];   // sparsity-metadata word, selected via spsel
  using BRegisters = uint64_t[1];   // B shared-memory matrix descriptor
  using CRegisters = float[104];    // 104 = 208/2 f32 accumulators (asm operands %0..%103)

  // Sparse warpgroup MMA: D(f32) += A(e4m3, registers) * B(e5m2, smem descriptor).
  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
      uint64_t const& desc_b,
      float & d000, float & d001, float & d002, float & d003,
      float & d004, float & d005, float & d006, float & d007,
      float & d008, float & d009, float & d010, float & d011,
      float & d012, float & d013, float & d014, float & d015,
      float & d016, float & d017, float & d018, float & d019,
      float & d020, float & d021, float & d022, float & d023,
      float & d024, float & d025, float & d026, float & d027,
      float & d028, float & d029, float & d030, float & d031,
      float & d032, float & d033, float & d034, float & d035,
      float & d036, float & d037, float & d038, float & d039,
      float & d040, float & d041, float & d042, float & d043,
      float & d044, float & d045, float & d046, float & d047,
      float & d048, float & d049, float & d050, float & d051,
      float & d052, float & d053, float & d054, float & d055,
      float & d056, float & d057, float & d058, float & d059,
      float & d060, float & d061, float & d062, float & d063,
      float & d064, float & d065, float & d066, float & d067,
      float & d068, float & d069, float & d070, float & d071,
      float & d072, float & d073, float & d074, float & d075,
      float & d076, float & d077, float & d078, float & d079,
      float & d080, float & d081, float & d082, float & d083,
      float & d084, float & d085, float & d086, float & d087,
      float & d088, float & d089, float & d090, float & d091,
      float & d092, float & d093, float & d094, float & d095,
      float & d096, float & d097, float & d098, float & d099,
      float & d100, float & d101, float & d102, float & d103,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      // p <- (scale_D != 0); passed as the scale-d predicate of the wgmma instruction.
      "setp.ne.b32 p, %111, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e4m3.e5m2 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103},"
      "{%104, %105, %106, %107},"
      " %108,"
      " %109, %110,"
      " p, %112, %113;\n"
    "}\n"
    :  "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
       "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
       "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
       "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
       "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
       "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
       "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
       "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
       "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
       "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
       "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
       "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
       "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
       "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
       "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
       "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
       "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
       "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
       "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
       "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
       "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
       "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
       "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
       "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
       "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
       "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103)
    :  "r"(a000), "r"(a001), "r"(a002), "r"(a003),
       "l"(desc_b),
       "r"(e), "n"(int32_t(spsel)),
       "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x216x64 TN F16+=E4M3*E5M2
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x216x64_F16E4M3E5M2_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];   // A shared-memory matrix descriptor (SS = smem * smem)
  using ERegisters = uint32_t[1];   // sparsity-metadata word, selected via spsel
  using BRegisters = uint64_t[1];   // B shared-memory matrix descriptor
  using CRegisters = uint32_t[54];  // 54 = 216/4 32-bit accumulator words (f16 data)

  // Sparse warpgroup MMA: D(f16) += A(e4m3) * B(e5m2), both operands via smem descriptors.
  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      // p <- (scale_D != 0); passed as the scale-d predicate of the wgmma instruction.
      "setp.ne.b32 p, %58, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n216k64.f16.e4m3.e5m2 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53},"
      " %54,"
      " %55,"
      " %56, %57,"
      " p, %59, %60;\n"
    "}\n"
    :  "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
       "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
       "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
       "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
       "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
       "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
       "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
       "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
       "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
       "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
       "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
       "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
       "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
       "+r"(d52), "+r"(d53)
    :  "l"(desc_a),
       "l"(desc_b),
       "r"(e), "n"(int32_t(spsel)),
       "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x216x64 TN F16+=E4M3*E5M2
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x216x64_F16E4M3E5M2_RS_TN
{
  using DRegisters = void;
  using
ARegisters = uint32_t[4];         // A fragment held in registers (RS = reg * smem)
  using ERegisters = uint32_t[1];   // sparsity-metadata word, selected via spsel
  using BRegisters = uint64_t[1];   // B shared-memory matrix descriptor
  using CRegisters = uint32_t[54];  // 54 = 216/4 32-bit accumulator words (f16 data)

  // Sparse warpgroup MMA: D(f16) += A(e4m3, registers) * B(e5m2, smem descriptor).
  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      // p <- (scale_D != 0); passed as the scale-d predicate of the wgmma instruction.
      "setp.ne.b32 p, %61, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n216k64.f16.e4m3.e5m2 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53},"
      "{%54, %55, %56, %57},"
      " %58,"
      " %59, %60,"
      " p, %62, %63;\n"
    "}\n"
    :  "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
       "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
       "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
       "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
       "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
       "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
       "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
       "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
       "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
       "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
       "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
       "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
       "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
       "+r"(d52), "+r"(d53)
    :  "r"(a00), "r"(a01), "r"(a02), "r"(a03),
       "l"(desc_b),
       "r"(e), "n"(int32_t(spsel)),
       "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x216x64 TN F32+=E4M3*E5M2
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x216x64_F32E4M3E5M2_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];   // A shared-memory matrix descriptor (SS = smem * smem)
  using ERegisters = uint32_t[1];   // sparsity-metadata word, selected via spsel
  using BRegisters = uint64_t[1];   // B shared-memory matrix descriptor
  using CRegisters = float[108];    // 108 = 216/2 f32 accumulators (asm operands %0..%107)

  // Sparse warpgroup MMA: D(f32) += A(e4m3) * B(e5m2), both operands via smem descriptors.
  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d000, float & d001, float & d002, float & d003,
      float & d004, float & d005, float & d006, float & d007,
      float & d008, float & d009, float & d010, float & d011,
      float & d012, float & d013, float & d014, float & d015,
      float & d016, float & d017, float & d018, float & d019,
      float & d020, float & d021, float & d022, float & d023,
      float & d024, float & d025, float & d026, float & d027,
      float & d028, float & d029, float & d030, float & d031,
      float & d032, float & d033, float & d034, float & d035,
      float & d036, float & d037, float & d038, float & d039,
      float & d040, float & d041, float & d042, float & d043,
      float & d044, float & d045, float & d046, float & d047,
      float & d048, float & d049, float & d050, float & d051,
      float & d052, float & d053, float & d054, float & d055,
      float & d056, float & d057, float & d058, float & d059,
      float & d060, float & d061, float & d062, float & d063,
      float & d064, float & d065, float & d066, float & d067,
      float & d068, float & d069, float & d070, float & d071,
      float & d072, float & d073, float & d074, float & d075,
      float & d076, float & d077, float & d078, float & d079,
      float & d080, float & d081, float & d082, float & d083,
      float & d084, float & d085, float & d086, float & d087,
      float & d088, float & d089, float & d090, float & d091,
      float & d092, float & d093, float & d094, float & d095,
      float & d096, float & d097, float & d098, float & d099,
      float & d100, float & d101, float & d102, float & d103,
      float & d104, float & d105, float & d106, float & d107,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      // p <- (scale_D != 0); passed as the scale-d predicate of the wgmma instruction.
      "setp.ne.b32 p, %112, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n216k64.f32.e4m3.e5m2 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103, "
      " %104, %105, %106, %107},"
      " %108,"
      " %109,"
      " %110, %111,"
      " p, %113, %114;\n"
    "}\n"
    :  "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
       "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
       "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
       "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
       "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
       "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
       "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
       "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
       "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
       "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
       "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
       "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
       "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
       "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
       "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
       "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
       "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
       "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
       "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
       "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
       "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
       "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
       "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
       "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
       "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
       "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
       "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107)
    :  "l"(desc_a),
       "l"(desc_b),
       "r"(e), "n"(int32_t(spsel)),
       "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x216x64 TN F32+=E4M3*E5M2
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x216x64_F32E4M3E5M2_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];   // A fragment held in registers (RS = reg * smem)
  using ERegisters = uint32_t[1];   // sparsity-metadata word, selected via spsel
  using BRegisters = uint64_t[1];   // B shared-memory matrix descriptor
  using CRegisters = float[108];    // 108 = 216/2 f32 accumulators (asm operands %0..%107)

  // Sparse warpgroup MMA: D(f32) += A(e4m3, registers) * B(e5m2, smem descriptor).
  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
      uint64_t const& desc_b,
      float & d000, float & d001, float & d002, float & d003,
      float & d004, float & d005, float & d006, float & d007,
      float & d008, float & d009, float & d010, float & d011,
      float & d012, float & d013, float & d014, float & d015,
      float & d016, float & d017, float & d018, float & d019,
      float & d020, float & d021, float & d022, float & d023,
      float & d024, float & d025, float & d026, float & d027,
      float & d028, float & d029, float & d030, float & d031,
      float & d032, float & d033, float & d034, float & d035,
      float & d036, float & d037, float & d038, float & d039,
      float & d040, float & d041, float & d042, float & d043,
      float & d044, float & d045, float & d046, float & d047,
      float & d048, float & d049, float & d050, float & d051,
      float & d052, float & d053, float & d054, float & d055,
      float & d056, float & d057, float & d058, float & d059,
      float & d060, float & d061, float & d062, float & d063,
      float & d064, float & d065, float & d066, float & d067,
      float & d068, float & d069, float & d070, float & d071,
      float & d072, float & d073, float & d074, float & d075,
      float & d076, float & d077, float & d078, float & d079,
      float & d080, float & d081, float & d082, float & d083,
      float & d084, float & d085, float & d086, float & d087,
      float & d088, float & d089, float & d090, float & d091,
      float & d092, float & d093, float & d094, float & d095,
      float & d096, float & d097, float & d098, float & d099,
      float & d100, float & d101, float & d102, float & d103,
      float & d104, float & d105, float & d106, float & d107,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      // p <- (scale_D != 0); passed as the scale-d predicate of the wgmma instruction.
      "setp.ne.b32 p, %115, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n216k64.f32.e4m3.e5m2 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103, "
      " %104, %105, %106, %107},"
      "{%108, %109, %110, %111},"
      " %112,"
      " %113, %114,"
      " p, %116, %117;\n"
    "}\n"
    :  "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
       "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
       "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
       "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
       "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
       "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
       "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
       "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
       "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
       "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
       "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
       "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
       "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
       "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
       "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
       "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
       "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
       "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
       "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
       "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
       "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
       "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
       "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
       "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
       "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
       "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
       "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107)
    :  "r"(a000), "r"(a001), "r"(a002), "r"(a003),
       "l"(desc_b),
       "r"(e), "n"(int32_t(spsel)),
       "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x224x64 TN F16+=E4M3*E5M2
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x224x64_F16E4M3E5M2_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];   // A shared-memory matrix descriptor (SS = smem * smem)
  using ERegisters = uint32_t[1];   // sparsity-metadata word, selected via spsel
  using BRegisters = uint64_t[1];   // B shared-memory matrix descriptor
  using CRegisters = uint32_t[56];  // 56 = 224/4 32-bit accumulator words (f16 data)

  // Sparse warpgroup MMA: D(f16) += A(e4m3) * B(e5m2), both operands via smem descriptors.
  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      // p <- (scale_D != 0); passed as the scale-d predicate of the wgmma instruction.
      "setp.ne.b32 p, %60, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e4m3.e5m2 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55},"
      " %56,"
      " %57,"
      " %58, %59,"
      " p, %61, %62;\n"
    "}\n"
    :  "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
       "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
       "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
       "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
       "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
       "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
       "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
       "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
       "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
       "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
       "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
       "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
       "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
       "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55)
    :  "l"(desc_a),
       "l"(desc_b),
       "r"(e), "n"(int32_t(spsel)),
       "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x224x64 TN F16+=E4M3*E5M2
template <
  GMMA::ScaleIn scaleA =
GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x224x64_F16E4M3E5M2_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];   // A fragment held in registers (RS = reg * smem)
  using ERegisters = uint32_t[1];   // sparsity-metadata word, selected via spsel
  using BRegisters = uint64_t[1];   // B shared-memory matrix descriptor
  using CRegisters = uint32_t[56];  // 56 = 224/4 32-bit accumulator words (f16 data)

  // Sparse warpgroup MMA: D(f16) += A(e4m3, registers) * B(e5m2, smem descriptor).
  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      // p <- (scale_D != 0); passed as the scale-d predicate of the wgmma instruction.
      "setp.ne.b32 p, %63, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e4m3.e5m2 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55},"
      "{%56, %57, %58, %59},"
      " %60,"
      " %61, %62,"
      " p, %64, %65;\n"
    "}\n"
    :  "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
       "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
       "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
       "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
       "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
       "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
       "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
       "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
       "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
       "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
       "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
       "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
       "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
       "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55)
    :  "r"(a00), "r"(a01), "r"(a02), "r"(a03),
       "l"(desc_b),
       "r"(e), "n"(int32_t(spsel)),
       "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x224x64 TN F32+=E4M3*E5M2
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x224x64_F32E4M3E5M2_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];   // A shared-memory matrix descriptor (SS = smem * smem)
  using ERegisters = uint32_t[1];   // sparsity-metadata word, selected via spsel
  using BRegisters = uint64_t[1];   // B shared-memory matrix descriptor
  using CRegisters = float[112];    // 112 = 224/2 f32 accumulators (asm operands %0..%111)

  // Sparse warpgroup MMA: D(f32) += A(e4m3) * B(e5m2), both operands via smem descriptors.
  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d000, float & d001, float & d002, float & d003,
      float & d004, float & d005, float & d006, float & d007,
      float & d008, float & d009, float & d010, float & d011,
      float & d012, float & d013, float & d014, float & d015,
      float & d016, float & d017, float & d018, float & d019,
      float & d020, float & d021, float & d022, float & d023,
      float & d024, float & d025, float & d026, float & d027,
      float & d028, float & d029, float & d030, float & d031,
      float & d032, float & d033, float & d034, float & d035,
      float & d036, float & d037, float & d038, float & d039,
      float & d040, float & d041, float & d042, float & d043,
      float & d044, float & d045, float & d046, float & d047,
      float & d048, float & d049, float & d050, float & d051,
      float & d052, float & d053, float & d054, float & d055,
      float & d056, float & d057, float & d058, float & d059,
      float & d060, float & d061, float & d062, float & d063,
      float & d064, float & d065, float & d066, float & d067,
      float & d068, float & d069, float & d070, float & d071,
      float & d072, float & d073, float & d074, float & d075,
      float & d076, float & d077, float & d078, float & d079,
      float & d080, float & d081, float & d082, float & d083,
      float & d084, float & d085, float & d086, float & d087,
      float & d088, float & d089, float & d090, float & d091,
      float & d092, float & d093, float & d094, float & d095,
      float & d096, float & d097, float & d098, float & d099,
      float & d100, float & d101, float & d102, float & d103,
      float & d104, float & d105, float & d106, float & d107,
      float & d108, float & d109, float & d110, float & d111,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      // p <- (scale_D != 0); passed as the scale-d predicate of the wgmma instruction.
      "setp.ne.b32 p, %116, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e4m3.e5m2 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103, "
      " %104, %105, %106, %107, %108, %109, %110, %111},"
      " %112,"
      " %113,"
      " %114, %115,"
      " p, %117, %118;\n"
    "}\n"
    :  "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
       "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
       "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
       "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
       "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
       "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
       "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
       "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
       "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
       "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
       "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
       "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
       "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
       "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
       "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
       "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
       "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
       "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
       "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
       "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
       "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
       "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
       "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
       "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
       "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
       "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
       "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107),
       "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111)
    :  "l"(desc_a),
       "l"(desc_b),
       "r"(e), "n"(int32_t(spsel)),
       "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x224x64 TN F32+=E4M3*E5M2
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x224x64_F32E4M3E5M2_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];   // A fragment held in registers (RS = reg * smem)
  using ERegisters = uint32_t[1];   // sparsity-metadata word, selected via spsel
  using BRegisters = uint64_t[1];   // B shared-memory matrix descriptor
  using CRegisters = float[112];    // 112 = 224/2 f32 accumulators (asm operands %0..%111)

  // Sparse warpgroup MMA: D(f32) += A(e4m3, registers) * B(e5m2, smem descriptor).
  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
      uint64_t const& desc_b,
      float & d000, float & d001, float & d002, float & d003,
      float & d004, float & d005, float & d006, float & d007,
      float & d008, float & d009, float & d010, float & d011,
      float & d012, float & d013, float & d014, float & d015,
      float & d016, float & d017, float & d018, float & d019,
      float & d020, float & d021, float & d022, float & d023,
      float & d024, float & d025, float & d026, float & d027,
      float & d028, float & d029, float & d030, float & d031,
      float & d032, float & d033, float & d034, float & d035,
      float & d036, float & d037, float & d038, float & d039,
      float & d040, float & d041, float & d042, float & d043,
      float & d044, float & d045, float & d046, float & d047,
      float & d048, float & d049, float & d050, float & d051,
      float & d052, float & d053, float & d054, float & d055,
      float & d056, float & d057, float & d058, float & d059,
      float & d060, float & d061, float & d062, float & d063,
      float & d064, float & d065, float & d066, float & d067,
      float & d068, float & d069, float & d070, float & d071,
      float & d072, float & d073, float & d074, float & d075,
      float & d076, float & d077, float & d078, float & d079,
      float & d080, float & d081, float & d082, float & d083,
      float & d084, float & d085, float & d086, float & d087,
float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), 
"+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & 
d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " %60, %61," + " p, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), 
"+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & 
d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " %63, %64," + " p, %66, %67;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel 
= GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & 
d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %120, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " %118, %119," + " p, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), 
"+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, 
float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %123, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, 
%89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " %121, %122," + " p, %124, %125;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F32E4M3E5M2_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, 
%7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t 
const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), 
"+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + 
float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, 
%100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; 
+ +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + 
float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), 
"+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, 
uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " %64, %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + 
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & 
d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " %67, %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, 
float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %128, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " %126, %127," + " p, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + 
"+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; 
+ using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + 
float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %131, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " %129, %130," + " p, %132, %133;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + 
"+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred 
p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " %8, %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " %11, %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x24x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters 
= uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " %12, %13," + " p, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const 
scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " %15, %16," + " p, %18, %19;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n40k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f32.e5m2.e4m3 " 
+ "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : 
"+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), 
"+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " %16, %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x56x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " %19, %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), 
"+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " %20, %21," + " p, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), 
+ "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " %23, %24," + " p, %26, %27;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + 
"r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), 
"+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + 
".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & 
d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, 
+ uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & 
d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + 
using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " %24, %25," + " p, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(e), 
"n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " %27, %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), 
"+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, 
%22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & 
d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " %28, %29," + " p, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 
SPARSE GMMA 64x104x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " %31, %32," + " p, %34, %35;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, 
" + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, 
float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } 
+}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), 
+ "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = 
uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), 
"+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" 
+ "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " %32, %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, 
uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " %35, %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, 
+ float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + 
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> 
+struct GMMA_64x136x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " %36, %37," + " p, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " %39, %40," + " p, %42, %43;\n" + "}\n" + : "+r"(d00), "+r"(d01), 
"+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & 
d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %72, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " %70, %71," + " p, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x136x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %75, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n136k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " %73, %74," + " p, %76, %77;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = 
void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + 
"}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & 
d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), 
"n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 
64x152x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " %40, %41," + " p, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), 
"+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n152k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " %43, %44," + " p, %46, %47;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & 
d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %80, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " %78, %79," + " p, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), 
"+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, 
float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %83, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " %81, %82," + " p, %84, %85;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), 
"+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t 
& d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn 
scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, 
" + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; 
+ using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, 
%68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + 
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " %44, %45," + " p, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x168x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, 
" + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " %47, %48," + " p, %50, %51;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & 
d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %88, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " %86, %87," + " p, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), 
"+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & 
d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %91, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " %89, %90," + " p, %92, %93;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), 
+ "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + 
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters 
= uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + 
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & 
d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), 
"+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, 
float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), 
"+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm 
volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + " %46," + " %47," + " %48, %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, 
+ uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " %51, %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float 
& d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %96, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " %94, %95," + " p, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), 
"+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, 
+ float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %99, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " %97, %98," + " p, %100, %101;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), 
"+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const 
scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " %52, %53," + " p, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = 
uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " %55, %56," + " p, %58, %59;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + 
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, 
      float & d057, float & d058, float & d059,
      float & d060, float & d061, float & d062, float & d063,
      float & d064, float & d065, float & d066, float & d067,
      float & d068, float & d069, float & d070, float & d071,
      float & d072, float & d073, float & d074, float & d075,
      float & d076, float & d077, float & d078, float & d079,
      float & d080, float & d081, float & d082, float & d083,
      float & d084, float & d085, float & d086, float & d087,
      float & d088, float & d089, float & d090, float & d091,
      float & d092, float & d093, float & d094, float & d095,
      float & d096, float & d097, float & d098, float & d099,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    // Operand map: %0-%99 = D, %100 = desc_a, %101 = desc_b, %102 = e,
    // %103 = spsel, %104 = scale_D (predicate p), %105/%106 = scaleA/scaleB.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %104, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n200k64.f32.e5m2.e4m3 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99},"
      " %100,"
      " %101,"
      " %102, %103,"
      " p, %105, %106;\n"
    "}\n"
      : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
        "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
        "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
        "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
        "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
        "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
        "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
        "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
        "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
        "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
        "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
        "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
        "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
        "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
        "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
        "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
        "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
        "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
        "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
        "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
        "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
        "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
        "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
        "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
        "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x200x64 TN F32+=E5M2*E4M3
// RS variant: A in registers a000..a003 (asm operands %100-%103); f32
// accumulator in 100 float registers; B via shared-memory descriptor.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x200x64_F32E5M2E4M3_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[100];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
      uint64_t const& desc_b,
      float & d000, float & d001, float & d002, float & d003,
      float & d004, float & d005, float & d006, float & d007,
      float & d008, float & d009, float & d010, float & d011,
      float & d012, float & d013, float & d014, float & d015,
      float & d016, float & d017, float & d018, float & d019,
      float & d020, float & d021, float & d022, float & d023,
      float & d024, float & d025, float & d026, float & d027,
      float & d028, float & d029, float & d030, float & d031,
      float & d032, float & d033, float & d034, float & d035,
      float & d036, float & d037, float & d038, float & d039,
      float & d040, float & d041, float & d042, float & d043,
      float & d044, float & d045, float & d046, float & d047,
      float & d048, float & d049, float & d050, float & d051,
      float & d052, float & d053, float & d054, float & d055,
      float & d056, float & d057, float & d058, float & d059,
      float & d060, float & d061, float & d062, float & d063,
      float & d064, float & d065, float & d066, float & d067,
      float & d068, float & d069, float & d070, float & d071,
      float & d072, float & d073, float & d074, float & d075,
      float & d076, float & d077, float & d078, float & d079,
      float & d080, float & d081, float & d082, float & d083,
      float & d084, float & d085, float & d086, float & d087,
      float & d088, float & d089, float & d090, float & d091,
      float & d092, float & d093, float & d094, float & d095,
      float & d096, float & d097, float & d098, float & d099,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %107, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n200k64.f32.e5m2.e4m3 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99},"
      "{%100, %101, %102, %103},"
      " %104,"
      " %105, %106,"
      " p, %108, %109;\n"
    "}\n"
      : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
        "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
        "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
        "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
        "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
        "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
        "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
        "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
        "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
        "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
        "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
        "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
        "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
        "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
        "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
        "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
        "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
        "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
        "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
        "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
        "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
        "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
        "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
        "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
        "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099)
      : "r"(a000), "r"(a001), "r"(a002), "r"(a003),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x208x64 TN F16+=E5M2*E4M3
template <
  GMMA::ScaleIn scaleA =
GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
// SS variant: A and B via shared-memory descriptors; f16 accumulator in
// 52 u32 registers; `e` = sparsity metadata, `spsel` = metadata selector.
struct GMMA_64x208x64_F16E5M2E4M3_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[52];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    // %0-%51 = D, %52 = desc_a, %53 = desc_b, %54 = e, %55 = spsel,
    // %56 = scale_D (predicate p), %57/%58 = scaleA/scaleB.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %56, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e5m2.e4m3 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51},"
      " %52,"
      " %53,"
      " %54, %55,"
      " p, %57, %58;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x208x64 TN F16+=E5M2*E4M3
// RS variant: A in registers a00..a03 (asm operands %52-%55); B via
// shared-memory descriptor; f16 accumulator in 52 u32 registers.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x208x64_F16E5M2E4M3_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[52];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %59, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e5m2.e4m3 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51},"
      "{%52, %53, %54, %55},"
      " %56,"
      " %57, %58,"
      " p, %60, %61;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51)
      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x208x64 TN F32+=E5M2*E4M3
// SS variant with an f32 accumulator: D occupies 104 float registers.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x208x64_F32E5M2E4M3_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[104];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d000, float & d001, float & d002, float & d003,
      float & d004, float & d005, float & d006, float & d007,
      float & d008, float & d009, float & d010, float & d011,
      float & d012, float & d013, float & d014, float & d015,
      float & d016, float & d017, float & d018, float & d019,
      float & d020, float & d021, float & d022, float & d023,
      float & d024, float & d025, float & d026, float & d027,
      float & d028, float & d029, float & d030, float & d031,
      float & d032, float & d033, float & d034, float & d035,
      float & d036, float & d037, float & d038, float & d039,
      float & d040, float & d041, float & d042, float & d043,
      float & d044, float & d045, float & d046, float & d047,
      float & d048, float & d049, float & d050, float & d051,
      float & d052, float & d053, float & d054, float & d055,
      float & d056, float & d057, float & d058, float & d059,
      float & d060, float & d061, float & d062, float & d063,
      float & d064, float & d065, float & d066, float & d067,
      float & d068, float & d069, float & d070, float & d071,
      float & d072, float & d073, float & d074, float & d075,
      float & d076, float & d077, float & d078, float & d079,
      float & d080, float & d081, float & d082, float & d083,
      float & d084, float & d085, float & d086, float & d087,
      float & d088, float & d089, float & d090, float & d091,
      float & d092, float & d093, float & d094, float & d095,
      float & d096, float & d097, float & d098, float & d099,
      float & d100, float & d101, float & d102, float & d103,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    // %0-%103 = D, %104 = desc_a, %105 = desc_b, %106 = e, %107 = spsel,
    // %108 = scale_D (predicate p), %109/%110 = scaleA/scaleB.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %108, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e5m2.e4m3 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103},"
      " %104,"
      " %105,"
      " %106, %107,"
      " p, %109, %110;\n"
    "}\n"
      : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
        "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
        "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
        "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
        "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
        "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
        "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
        "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
        "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
        "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
        "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
        "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
        "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
        "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
        "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
        "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
        "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
        "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
        "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
        "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
        "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
        "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
        "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
        "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
        "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
        "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x208x64 TN F32+=E5M2*E4M3
// RS variant: A in registers a000..a003 (asm operands %104-%107); f32
// accumulator in 104 float registers; B via shared-memory descriptor.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x208x64_F32E5M2E4M3_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[104];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
      uint64_t const& desc_b,
      float & d000, float & d001, float & d002, float & d003,
      float & d004, float & d005, float & d006, float & d007,
      float & d008, float & d009, float & d010, float & d011,
      float & d012, float & d013, float & d014, float & d015,
      float & d016, float & d017, float & d018, float & d019,
      float & d020, float & d021, float & d022, float & d023,
      float & d024, float & d025, float & d026, float & d027,
      float & d028, float & d029, float & d030, float & d031,
      float & d032, float & d033, float & d034, float & d035,
      float & d036, float & d037, float & d038, float & d039,
      float & d040, float & d041, float & d042, float & d043,
      float & d044, float & d045, float & d046, float & d047,
      float & d048, float & d049, float & d050, float & d051,
      float & d052, float & d053, float & d054, float & d055,
      float & d056, float & d057, float & d058, float & d059,
      float & d060, float & d061, float & d062, float & d063,
      float & d064, float & d065, float & d066, float & d067,
      float & d068, float & d069, float & d070, float & d071,
      float & d072, float & d073, float & d074, float & d075,
      float & d076, float & d077, float & d078, float & d079,
      float & d080, float & d081, float & d082, float & d083,
      float & d084, float & d085, float & d086, float & d087,
      float & d088, float & d089, float & d090, float & d091,
      float & d092, float & d093, float & d094, float & d095,
      float & d096, float & d097, float & d098, float & d099,
      float & d100, float & d101, float & d102, float & d103,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %111, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e5m2.e4m3 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103},"
      "{%104, %105, %106, %107},"
      " %108,"
      " %109, %110,"
      " p, %112, %113;\n"
    "}\n"
      : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
        "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
        "+f"(d008), "+f"(d009),
uint32_t[54];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    // %0-%53 = D, %54 = desc_a, %55 = desc_b, %56 = e, %57 = spsel,
    // %58 = scale_D (predicate p), %59/%60 = scaleA/scaleB.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %58, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n216k64.f16.e5m2.e4m3 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53},"
      " %54,"
      " %55,"
      " %56, %57,"
      " p, %59, %60;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
        "+r"(d52), "+r"(d53)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x216x64 TN F16+=E5M2*E4M3
// RS variant: A in registers a00..a03 (asm operands %54-%57); B via
// shared-memory descriptor; f16 accumulator in 54 u32 registers.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x216x64_F16E5M2E4M3_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[54];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %61, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n216k64.f16.e5m2.e4m3 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53},"
      "{%54, %55, %56, %57},"
      " %58,"
      " %59, %60,"
      " p, %62, %63;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
        "+r"(d52), "+r"(d53)
      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x216x64 TN F32+=E5M2*E4M3
template <
  GMMA::ScaleIn
scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t 
const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %112, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " %110, %111," + " p, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), 
"+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, 
float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %115, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " %113, %114," + " p, %116, %117;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), 
"+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = 
uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), 
"+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, 
uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E4M3_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, 
float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), 
"+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, 
float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, 
%66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f16.e5m2.e4m3 " + 
"{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " %60, %61," + " p, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t 
const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " %63, %64," + " p, %66, %67;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), 
"+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + 
float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %120, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " 
%116," + " %117," + " %118, %119," + " p, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, 
float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %123, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " %121, %122," + " p, %124, %125;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + 
"+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & 
d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + 
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t 
const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + 
GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, 
float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + 
"+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & 
d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, 
%50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), 
"+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut 
const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " %64, %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " 
%48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " %67, %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, 
float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %128, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, 
%15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " %126, %127," + " p, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), 
"+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + 
float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %131, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, 
%110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " %129, %130," + " p, %132, %133;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x248x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " %8, %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters 
= uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " %11, %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," 
+ "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " %12, %13," + " p, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); 
+#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " %15, %16," + " p, %18, %19;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 
64x40x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + 
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & 
d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & 
d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, 
uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " %16, %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + 
uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " %19, %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) 
+ { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, 
float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, 
uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " %20, %21," + " p, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + 
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " %23, %24," + " p, %26, %27;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & 
d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters 
= uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F32E5M2E5M2_RS_TN 
without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif 
+ } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E5M2_RS_TN 
without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), 
"+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, 
%47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, 
      uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    // Predicate p <- (scale_D != 0); p is passed as the scale-d operand, so
    // scale_D==Zero makes the instruction overwrite D instead of accumulating.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %26, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n88k64.f16.e5m2.e5m2 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21},"
      " %22,"
      " %23,"
      " %24, %25,"
      " p, %27, %28;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x88x64 TN F16+=E5M2*E5M2
// RS variant: A is sourced from registers; B from a shared-memory descriptor.
// `e` is the sparsity-metadata word (ERegisters); `spsel` selects the metadata thread group.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x88x64_F16E5M2E5M2_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[22];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %29, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n88k64.f16.e5m2.e5m2 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21},"
      "{%22, %23, %24, %25},"
      " %26,"
      " %27, %28,"
      " p, %30, %31;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21)
      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x88x64 TN F32+=E5M2*E5M2
// SS variant: both A and B are sourced from shared memory via 64-bit matrix descriptors.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x88x64_F32E5M2E5M2_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[44];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      float & d28, float & d29, float & d30, float & d31,
      float & d32, float & d33, float & d34, float & d35,
      float & d36, float & d37, float & d38, float & d39,
      float & d40, float & d41, float & d42, float & d43,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %48, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n88k64.f32.e5m2.e5m2 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43},"
      " %44,"
      " %45,"
      " %46, %47,"
      " p, %49, %50;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x88x64 TN F32+=E5M2*E5M2
// RS variant: A is sourced from registers; B from a shared-memory descriptor.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x88x64_F32E5M2E5M2_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[44];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      float & d28, float & d29, float & d30, float & d31,
      float & d32, float & d33, float & d34, float & d35,
      float & d36, float & d37, float & d38, float & d39,
      float & d40, float & d41, float & d42, float & d43,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %51, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n88k64.f32.e5m2.e5m2 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43},"
      "{%44, %45, %46, %47},"
      " %48,"
      " %49, %50,"
      " p, %52, %53;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43)
      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x104x64 TN F16+=E5M2*E5M2
// SS variant: both A and B are sourced from shared memory via 64-bit matrix descriptors.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x104x64_F16E5M2E5M2_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[26];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %30, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n104k64.f16.e5m2.e5m2 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25},"
      " %26,"
      " %27,"
      " %28, %29,"
      " p, %31, %32;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x104x64 TN F16+=E5M2*E5M2
// RS variant: A is sourced from registers; B from a shared-memory descriptor.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x104x64_F16E5M2E5M2_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[26];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %33, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n104k64.f16.e5m2.e5m2 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25},"
      "{%26, %27, %28, %29},"
      " %30,"
      " %31, %32,"
      " p, %34, %35;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25)
      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x104x64 TN F32+=E5M2*E5M2
// SS variant: both A and B are sourced from shared memory via 64-bit matrix descriptors.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x104x64_F32E5M2E5M2_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[52];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      float & d28, float & d29, float & d30, float & d31,
      float & d32, float & d33, float & d34, float & d35,
      float & d36, float & d37, float & d38, float & d39,
      float & d40, float & d41, float & d42, float & d43,
      float & d44, float & d45, float & d46, float & d47,
      float & d48, float & d49, float & d50, float & d51,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %56, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n104k64.f32.e5m2.e5m2 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51},"
      " %52,"
      " %53,"
      " %54, %55,"
      " p, %57, %58;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x104x64 TN F32+=E5M2*E5M2
// RS variant: A is sourced from registers; B from a shared-memory descriptor.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x104x64_F32E5M2E5M2_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[52];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      float & d28, float & d29, float & d30, float & d31,
      float & d32, float & d33, float & d34, float & d35,
      float & d36, float & d37, float & d38, float & d39,
      float & d40, float & d41, float & d42, float & d43,
      float & d44, float & d45, float & d46, float & d47,
      float & d48, float & d49, float & d50, float & d51,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %59, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n104k64.f32.e5m2.e5m2 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51},"
      "{%52, %53, %54, %55},"
      " %56,"
      " %57, %58,"
      " p, %60, %61;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51)
      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x112x64 TN F16+=E5M2*E5M2
// SS variant: both A and B are sourced from shared memory via 64-bit matrix descriptors.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x112x64_F16E5M2E5M2_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[28];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %32, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e5m2.e5m2 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27},"
      " %28,"
      " %29,"
      " %30, %31,"
      " p, %33, %34;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x112x64 TN F16+=E5M2*E5M2
// RS variant: A is sourced from registers; B from a shared-memory descriptor.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x112x64_F16E5M2E5M2_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[28];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %35, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e5m2.e5m2 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27},"
      "{%28, %29, %30, %31},"
      " %32,"
      " %33, %34,"
      " p, %36, %37;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27)
      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x112x64 TN F32+=E5M2*E5M2
// SS variant: both A and B are sourced from shared memory via 64-bit matrix descriptors.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x112x64_F32E5M2E5M2_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[56];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      float & d28, float & d29, float & d30, float & d31,
      float & d32, float & d33, float & d34, float & d35,
      float & d36, float & d37, float & d38, float & d39,
      float & d40, float & d41, float & d42, float & d43,
      float & d44, float & d45, float & d46, float & d47,
      float & d48, float & d49, float & d50, float & d51,
      float & d52, float & d53, float & d54, float & d55,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %60, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e5m2.e5m2 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55},"
      " %56,"
      " %57,"
      " %58, %59,"
      " p, %61, %62;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
        "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x112x64 TN F32+=E5M2*E5M2
// RS variant: A is sourced from registers; B from a shared-memory descriptor.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x112x64_F32E5M2E5M2_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[56];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      float & d28, float & d29, float & d30, float & d31,
      float & d32, float & d33, float & d34, float & d35,
      float & d36, float & d37, float & d38, float & d39,
      float & d40, float & d41, float & d42, float & d43,
      float & d44, float & d45, float & d46, float & d47,
      float & d48, float & d49, float & d50, float & d51,
      float & d52, float & d53, float & d54, float & d55,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %63, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e5m2.e5m2 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55},"
      "{%56, %57, %58, %59},"
      " %60,"
      " %61, %62,"
      " p, %64, %65;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
        "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55)
      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x120x64 TN F16+=E5M2*E5M2
// SS variant: both A and B are sourced from shared memory via 64-bit matrix descriptors.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x120x64_F16E5M2E5M2_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[30];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %34, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n120k64.f16.e5m2.e5m2 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29},"
      " %30,"
      " %31,"
      " %32, %33,"
      " p, %35, %36;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x120x64 TN F16+=E5M2*E5M2
// RS variant: A is sourced from registers; B from a shared-memory descriptor.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x120x64_F16E5M2E5M2_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[30];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %37, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n120k64.f16.e5m2.e5m2 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29},"
      "{%30, %31, %32, %33},"
      " %34,"
      " %35, %36,"
      " p, %38, %39;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29)
      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x120x64 TN F32+=E5M2*E5M2
// SS variant: both A and B are sourced from shared memory via 64-bit matrix descriptors.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x120x64_F32E5M2E5M2_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[60];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      float & d28, float & d29, float & d30, float & d31,
      float & d32, float & d33, float & d34, float & d35,
      float & d36, float & d37, float & d38, float & d39,
      float & d40, float & d41, float & d42, float & d43,
      float & d44, float & d45, float & d46, float & d47,
      float & d48, float & d49, float & d50, float & d51,
      float & d52, float & d53, float & d54, float & d55,
      float & d56, float & d57, float & d58, float & d59,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %64, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n120k64.f32.e5m2.e5m2 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59},"
      " %60,"
      " %61,"
      " %62, %63,"
      " p, %65, %66;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
        "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
        "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
        "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
        "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
        "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
        "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
        "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
        "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
        "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x120x64 TN F32+=E5M2*E5M2
// RS variant: A is sourced from registers; B from a shared-memory descriptor.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x120x64_F32E5M2E5M2_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[60];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      float & d00, float & d01, float & d02, float & d03,
      float & d04, float & d05, float & d06, float & d07,
      float & d08, float & d09, float & d10, float & d11,
      float & d12, float & d13, float & d14, float & d15,
      float & d16, float & d17, float & d18, float & d19,
      float & d20, float & d21, float & d22, float & d23,
      float & d24, float & d25, float & d26, float & d27,
      float & d28, float & d29, float & d30, float & d31,
      float & d32, float & d33, float & d34, float & d35,
      float & d36, float & d37, float & d38, float & d39,
      float & d40, float & d41, float & d42, float & d43,
      float & d44, float & d45, float & d46, float & d47,
      float & d48, float & d49, float & d50, float & d51,
      float & d52, float & d53, float & d54, float & d55,
      float & d56, float & d57, float & d58, float & d59,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %67, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n120k64.f32.e5m2.e5m2 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59},"
      "{%60, %61, %62, %63},"
      " %64,"
      " %65, %66,"
      " p, %68, %69;\n"
    "}\n"
      : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
        "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
        "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
        "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
        "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
        "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " %36, %37," + " p, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, 
uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " %39, %40," + " p, %42, %43;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %72, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " %70, %71," + " p, %73, 
%74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, 
float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %75, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " %73, %74," + " p, %76, %77;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), 
"+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + 
"setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & 
d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E5M2E5M2_SS_TN +{ 
+ using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, 
%77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & 
d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + 
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut 
const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " %40, %41," + " p, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, 
uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " %43, %44," + " p, %46, %47;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %80, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, 
%10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " %78, %79," + " p, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using 
ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %83, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, 
%67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " %81, %82," + " p, %84, %85;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + 
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), 
"+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, 
float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), 
"+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & 
d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), 
"r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, 
%15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " %44, %45," + " p, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + 
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " %47, %48," + " p, %50, %51;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %88, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, 
" + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " %86, %87," + " p, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x168x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %91, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, 
%36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " %89, %90," + " p, %92, %93;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E5M2E5M2_SS_TN +{ + 
using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), 
"+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm 
volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & 
d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), 
"+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, 
float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, 
%97;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & 
d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + " %46," + " %47," + " %48, %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, 
%10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " %51, %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & 
d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %96, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " %94, %95," + " p, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), 
"+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, 
float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %99, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, 
%93, %94, %95}," + " %96," + " %97, %98," + " p, %100, %101;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + 
fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " %52, %53," + " p, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), 
"+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { 
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " %55, %56," + " p, %58, %59;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = 
float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %104, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, 
%22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " %102, %103," + " p, %105, %106;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x200x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & 
d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %107, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " %105, %106," + " p, %108, %109;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), 
"+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, 
uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F16+=E5M2*E5M2 +template < + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " 
%56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & 
d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," 
+ " %104," + " %105," + " %106, %107," + " p, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x208x64_F32E5M2E5M2_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[104];

  // RS variant: the A fragment arrives in registers (a000-a003); B is read from
  // shared memory through descriptor `desc_b`. `e` is the sparsity-metadata word.
  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
      uint64_t const& desc_b,
      float & d000, float & d001, float & d002, float & d003,
      float & d004, float & d005, float & d006, float & d007,
      float & d008, float & d009, float & d010, float & d011,
      float & d012, float & d013, float & d014, float & d015,
      float & d016, float & d017, float & d018, float & d019,
      float & d020, float & d021, float & d022, float & d023,
      float & d024, float & d025, float & d026, float & d027,
      float & d028, float & d029, float & d030, float & d031,
      float & d032, float & d033, float & d034, float & d035,
      float & d036, float & d037, float & d038, float & d039,
      float & d040, float & d041, float & d042, float & d043,
      float & d044, float & d045, float & d046, float & d047,
      float & d048, float & d049, float & d050, float & d051,
      float & d052, float & d053, float & d054, float & d055,
      float & d056, float & d057, float & d058, float & d059,
      float & d060, float & d061, float & d062, float & d063,
      float & d064, float & d065, float & d066, float & d067,
      float & d068, float & d069, float & d070, float & d071,
      float & d072, float & d073, float & d074, float & d075,
      float & d076, float & d077, float & d078, float & d079,
      float & d080, float & d081, float & d082, float & d083,
      float & d084, float & d085, float & d086, float & d087,
      float & d088, float & d089, float & d090, float & d091,
      float & d092, float & d093, float & d094, float & d095,
      float & d096, float & d097, float & d098, float & d099,
      float & d100, float & d101, float & d102, float & d103,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    // Operand map: %0-%103 D accumulators, %104-%107 A fragment, %108 desc_b,
    // %109 metadata e, %110 spsel, p = (scale_D != 0), %112 scaleA, %113 scaleB.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %111, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e5m2.e5m2 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103},"
      "{%104, %105, %106, %107},"
      " %108,"
      " %109, %110,"
      " p, %112, %113;\n"
    "}\n"
      : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
        "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
        "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
        "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
        "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
        "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
        "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
        "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
        "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
        "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
        "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
        "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
        "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
        "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
        "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
        "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
        "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
        "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
        "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
        "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
        "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
        "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
        "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
        "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
        "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
        "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103)
      : "r"(a000), "r"(a001), "r"(a002), "r"(a003),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x216x64 TN F16+=E5M2*E5M2
// SS variant: both A and B are read from shared memory via matrix descriptors.
// `e` is the sparsity-metadata word; scale_D is lowered to predicate p.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x216x64_F16E5M2E5M2_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[54];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    // Operand map: %0-%53 D accumulators, %54 desc_a, %55 desc_b,
    // %56 metadata e, %57 spsel, p = (scale_D != 0), %59 scaleA, %60 scaleB.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %58, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n216k64.f16.e5m2.e5m2 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53},"
      " %54,"
      " %55,"
      " %56, %57,"
      " p, %59, %60;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
        "+r"(d52), "+r"(d53)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x216x64 TN F16+=E5M2*E5M2
// RS variant: A fragment in registers (a00-a03); B via shared-memory descriptor.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x216x64_F16E5M2E5M2_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[54];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    // Operand map: %0-%53 D accumulators, %54-%57 A fragment, %58 desc_b,
    // %59 metadata e, %60 spsel, p = (scale_D != 0), %62 scaleA, %63 scaleB.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %61, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n216k64.f16.e5m2.e5m2 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53},"
      "{%54, %55, %56, %57},"
      " %58,"
      " %59, %60,"
      " p, %62, %63;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
        "+r"(d52), "+r"(d53)
      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x216x64 TN F32+=E5M2*E5M2
// SS variant with f32 accumulation: 108 float accumulator registers.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x216x64_F32E5M2E5M2_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[108];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d000, float & d001, float & d002, float & d003,
      float & d004, float & d005, float & d006, float & d007,
      float & d008, float & d009, float & d010, float & d011,
      float & d012, float & d013, float & d014, float & d015,
      float & d016, float & d017, float & d018, float & d019,
      float & d020, float & d021, float & d022, float & d023,
      float & d024, float & d025, float & d026, float & d027,
      float & d028, float & d029, float & d030, float & d031,
      float & d032, float & d033, float & d034, float & d035,
      float & d036, float & d037, float & d038, float & d039,
float & d040, float & d041, float & d042, float & d043,
      float & d044, float & d045, float & d046, float & d047,
      float & d048, float & d049, float & d050, float & d051,
      float & d052, float & d053, float & d054, float & d055,
      float & d056, float & d057, float & d058, float & d059,
      float & d060, float & d061, float & d062, float & d063,
      float & d064, float & d065, float & d066, float & d067,
      float & d068, float & d069, float & d070, float & d071,
      float & d072, float & d073, float & d074, float & d075,
      float & d076, float & d077, float & d078, float & d079,
      float & d080, float & d081, float & d082, float & d083,
      float & d084, float & d085, float & d086, float & d087,
      float & d088, float & d089, float & d090, float & d091,
      float & d092, float & d093, float & d094, float & d095,
      float & d096, float & d097, float & d098, float & d099,
      float & d100, float & d101, float & d102, float & d103,
      float & d104, float & d105, float & d106, float & d107,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    // Operand map: %0-%107 D accumulators, %108 desc_a, %109 desc_b,
    // %110 metadata e, %111 spsel, p = (scale_D != 0), %113 scaleA, %114 scaleB.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %112, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n216k64.f32.e5m2.e5m2 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103, "
      " %104, %105, %106, %107},"
      " %108,"
      " %109,"
      " %110, %111,"
      " p, %113, %114;\n"
    "}\n"
      : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
        "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
        "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
        "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
        "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
        "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
        "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
        "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
        "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
        "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
        "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
        "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
        "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
        "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
        "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
        "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
        "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
        "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
        "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
        "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
        "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
        "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
        "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
        "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
        "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
        "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
        "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x216x64 TN F32+=E5M2*E5M2
// RS variant: A fragment in registers; B via shared-memory descriptor.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x216x64_F32E5M2E5M2_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[108];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
      uint64_t const& desc_b,
      float & d000, float & d001, float & d002, float & d003,
      float & d004, float & d005, float & d006, float & d007,
      float & d008, float & d009, float & d010, float & d011,
      float & d012, float & d013, float & d014, float & d015,
      float & d016, float & d017, float & d018, float & d019,
      float & d020, float & d021, float & d022, float & d023,
      float & d024, float & d025, float & d026, float & d027,
      float & d028, float & d029, float & d030, float & d031,
      float & d032, float & d033, float & d034, float & d035,
      float & d036, float & d037, float & d038, float & d039,
      float & d040, float & d041, float & d042, float & d043,
      float & d044, float & d045, float & d046, float & d047,
      float & d048, float & d049, float & d050, float & d051,
      float & d052, float & d053, float & d054, float & d055,
      float & d056, float & d057, float & d058, float & d059,
      float & d060, float & d061, float & d062, float & d063,
      float & d064, float & d065, float & d066, float & d067,
      float & d068, float & d069, float & d070, float & d071,
      float & d072, float & d073, float & d074, float & d075,
      float & d076, float & d077, float & d078, float & d079,
      float & d080, float & d081, float & d082, float & d083,
      float & d084, float & d085, float & d086, float & d087,
      float & d088, float & d089, float & d090, float & d091,
      float & d092, float & d093, float & d094, float & d095,
      float & d096, float & d097, float & d098, float & d099,
      float & d100, float & d101, float & d102, float & d103,
      float & d104, float & d105, float & d106, float & d107,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    // Operand map: %0-%107 D accumulators, %108-%111 A fragment, %112 desc_b,
    // %113 metadata e, %114 spsel, p = (scale_D != 0), %116 scaleA, %117 scaleB.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %115, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n216k64.f32.e5m2.e5m2 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103, "
      " %104, %105, %106, %107},"
      "{%108, %109, %110, %111},"
      " %112,"
      " %113, %114,"
      " p, %116, %117;\n"
    "}\n"
      : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
        "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
        "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
        "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
        "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
        "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
        "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
        "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
        "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
        "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
        "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
        "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
        "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
        "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
        "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
        "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
        "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
        "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
        "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
        "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
        "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
"+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
        "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
        "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
        "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
        "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
        "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107)
      : "r"(a000), "r"(a001), "r"(a002), "r"(a003),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x224x64 TN F16+=E5M2*E5M2
// SS variant: both A and B are read from shared memory via matrix descriptors.
// `e` is the sparsity-metadata word; scale_D is lowered to predicate p.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x224x64_F16E5M2E5M2_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[56];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    // Operand map: %0-%55 D accumulators, %56 desc_a, %57 desc_b,
    // %58 metadata e, %59 spsel, p = (scale_D != 0), %61 scaleA, %62 scaleB.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %60, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e5m2.e5m2 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55},"
      " %56,"
      " %57,"
      " %58, %59,"
      " p, %61, %62;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x224x64 TN F16+=E5M2*E5M2
// RS variant: A fragment in registers (a00-a03); B via shared-memory descriptor.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x224x64_F16E5M2E5M2_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[56];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
      uint64_t const& desc_b,
      uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
      uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
      uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
      uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
      uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
      uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
      uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
      uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
      uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
      uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
      uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
      uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
      uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
      uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
    // Operand map: %0-%55 D accumulators, %56-%59 A fragment, %60 desc_b,
    // %61 metadata e, %62 spsel, p = (scale_D != 0), %64 scaleA, %65 scaleB.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %63, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e5m2.e5m2 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55},"
      "{%56, %57, %58, %59},"
      " %60,"
      " %61, %62,"
      " p, %64, %65;\n"
    "}\n"
      : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
        "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
        "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
        "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
        "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
        "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
        "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
        "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
        "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
        "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
        "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
        "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
        "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
        "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55)
      : "r"(a00), "r"(a01), "r"(a02), "r"(a03),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x224x64 TN F32+=E5M2*E5M2
// SS variant with f32 accumulation: 112 float accumulator registers.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x224x64_F32E5M2E5M2_SS_TN
{
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[112];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      float & d000, float & d001, float & d002, float & d003,
      float & d004, float & d005, float & d006, float & d007,
      float & d008, float & d009, float & d010, float & d011,
      float & d012, float & d013, float & d014, float & d015,
      float & d016, float & d017, float & d018, float & d019,
      float & d020, float & d021,
float & d022, float & d023,
      float & d024, float & d025, float & d026, float & d027,
      float & d028, float & d029, float & d030, float & d031,
      float & d032, float & d033, float & d034, float & d035,
      float & d036, float & d037, float & d038, float & d039,
      float & d040, float & d041, float & d042, float & d043,
      float & d044, float & d045, float & d046, float & d047,
      float & d048, float & d049, float & d050, float & d051,
      float & d052, float & d053, float & d054, float & d055,
      float & d056, float & d057, float & d058, float & d059,
      float & d060, float & d061, float & d062, float & d063,
      float & d064, float & d065, float & d066, float & d067,
      float & d068, float & d069, float & d070, float & d071,
      float & d072, float & d073, float & d074, float & d075,
      float & d076, float & d077, float & d078, float & d079,
      float & d080, float & d081, float & d082, float & d083,
      float & d084, float & d085, float & d086, float & d087,
      float & d088, float & d089, float & d090, float & d091,
      float & d092, float & d093, float & d094, float & d095,
      float & d096, float & d097, float & d098, float & d099,
      float & d100, float & d101, float & d102, float & d103,
      float & d104, float & d105, float & d106, float & d107,
      float & d108, float & d109, float & d110, float & d111,
      uint32_t const& e,
      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
  {
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
    // Operand map: %0-%111 D accumulators, %112 desc_a, %113 desc_b,
    // %114 metadata e, %115 spsel, p = (scale_D != 0), %117 scaleA, %118 scaleB.
    asm volatile(
    "{\n"
      ".reg .pred p;\n"
      "setp.ne.b32 p, %116, 0;\n"
      "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e5m2.e5m2 "
      "{%0, %1, %2, %3, %4, %5, %6, %7, "
      " %8, %9, %10, %11, %12, %13, %14, %15, "
      " %16, %17, %18, %19, %20, %21, %22, %23, "
      " %24, %25, %26, %27, %28, %29, %30, %31, "
      " %32, %33, %34, %35, %36, %37, %38, %39, "
      " %40, %41, %42, %43, %44, %45, %46, %47, "
      " %48, %49, %50, %51, %52, %53, %54, %55, "
      " %56, %57, %58, %59, %60, %61, %62, %63, "
      " %64, %65, %66, %67, %68, %69, %70, %71, "
      " %72, %73, %74, %75, %76, %77, %78, %79, "
      " %80, %81, %82, %83, %84, %85, %86, %87, "
      " %88, %89, %90, %91, %92, %93, %94, %95, "
      " %96, %97, %98, %99, %100, %101, %102, %103, "
      " %104, %105, %106, %107, %108, %109, %110, %111},"
      " %112,"
      " %113,"
      " %114, %115,"
      " p, %117, %118;\n"
    "}\n"
      : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
        "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
        "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
        "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
        "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
        "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
        "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
        "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
        "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
        "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
        "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
        "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
        "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
        "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
        "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
        "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
        "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
        "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
        "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
        "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
        "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
        "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
        "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
        "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
        "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
        "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
        "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107),
        "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111)
      : "l"(desc_a),
        "l"(desc_b),
        "r"(e), "n"(int32_t(spsel)),
        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// SPARSE GMMA 64x224x64 TN F32+=E5M2*E5M2
// RS variant: A fragment in registers; B via shared-memory descriptor.
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
>
struct GMMA_64x224x64_F32E5M2E5M2_RS_TN
{
  using DRegisters = void;
  using ARegisters = uint32_t[4];
  using ERegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = float[112];

  CUTE_HOST_DEVICE static void
  fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
      uint64_t const& desc_b,
      float & d000, float & d001, float & d002, float & d003,
      float & d004, float & d005, float & d006, float & d007,
      float & d008, float & d009, float & d010, float & d011,
      float & d012, float & d013, float & d014, float & d015,
      float & d016, float & d017, float & d018, float & d019,
      float & d020, float & d021, float & d022, float & d023,
      float & d024, float & d025, float & d026, float & d027,
      float & d028, float & d029, float & d030, float & d031,
      float & d032, float & d033, float & d034, float & d035,
      float & d036, float & d037, float & d038, float & d039,
      float & d040, float & d041, float & d042, float & d043,
      float & d044, float & d045, float & d046, float & d047,
      float & d048, float & d049, float & d050, float & d051,
      float & d052, float & d053, float & d054, float & d055,
      float & d056, float & d057, float & d058, float & d059,
      float & d060, float & d061, float & d062, float & d063,
      float & d064, float & d065, float & d066, float & d067,
      float & d068, float & d069, float & d070, float & d071,
      float & d072, float & d073, float & d074, float & d075,
      float & d076, float & d077, float & d078, float & d079,
      float & d080, float & d081, float & d082, float & d083,
      float & d084, float & d085, float & d086, float & d087,
float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), 
"+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & 
d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " %60, %61," + " p, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), 
"+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & 
d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " %63, %64," + " p, %66, %67;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel 
= GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & 
d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %120, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " %118, %119," + " p, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), 
"+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, 
float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %123, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, 
%89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " %121, %122," + " p, %124, %125;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F32E5M2E5M2_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, 
%7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t 
const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), 
"+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + 
float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, 
%100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; 
+ +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + 
float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), 
"+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, 
uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " %64, %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + 
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & 
d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " %67, %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, 
float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %128, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " %126, %127," + " p, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + 
"+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; 
+ using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + 
float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %131, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " %129, %130," + " p, %132, %133;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + 
"+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace SM90::GMMA::SPARSE + +} // namespace cute diff --git a/include/cute/atom/mma_traits_sm90_gmma.hpp b/include/cute/atom/mma_traits_sm90_gmma.hpp index 03016fa90..b02f5b3af 100644 --- a/include/cute/atom/mma_traits_sm90_gmma.hpp +++ b/include/cute/atom/mma_traits_sm90_gmma.hpp @@ -441,66 +441,26 @@ mma_unpack(MMA_Traits const& traits, } // Accumulator layouts -using CLayout_64x8 = Layout,Shape < _2,_2>>, - Stride,Stride<_64,_8>>>; - -using CLayout_64x16 = Layout,Shape < _2,_2, _2>>, - Stride,Stride<_64,_8,_512>>>; - -using CLayout_64x32 = Layout,Shape < _2,_2, _4>>, - Stride,Stride<_64,_8,_512>>>; - -using CLayout_64x40 = Layout,Shape < _2,_2, _5>>, - Stride,Stride<_64,_8,_512>>>; - -using CLayout_64x48 = Layout,Shape < _2,_2, _6>>, - Stride,Stride<_64,_8,_512>>>; - -using CLayout_64x64 = 
Layout,Shape < _2,_2, _8>>, - Stride,Stride<_64,_8,_512>>>; - -using CLayout_64x80 = Layout,Shape < _2,_2, _10>>, - Stride,Stride<_64,_8,_512>>>; - -using CLayout_64x96 = Layout,Shape < _2,_2, _12>>, - Stride,Stride<_64,_8,_512>>>; - -using CLayout_64x112 = Layout,Shape < _2,_2, Int<14>>>, - Stride,Stride<_64,_8,_512>>>; - -using CLayout_64x128 = Layout,Shape < _2,_2, _16>>, - Stride,Stride<_64,_8,_512>>>; - -using CLayout_64x144 = Layout,Shape < _2,_2, Int<18>>>, - Stride,Stride<_64,_8,_512>>>; - -using CLayout_64x160 = Layout,Shape < _2,_2, Int<20>>>, - Stride,Stride<_64,_8,_512>>>; - -using CLayout_64x176 = Layout,Shape < _2,_2, Int<22>>>, - Stride,Stride<_64,_8,_512>>>; - -using CLayout_64x192 = Layout,Shape < _2,_2, _24>>, - Stride,Stride<_64,_8,_512>>>; - -using CLayout_64x208 = Layout,Shape < _2,_2, Int<26>>>, - Stride,Stride<_64,_8,_512>>>; - -using CLayout_64x224 = Layout,Shape < _2,_2, Int<28>>>, - Stride,Stride<_64,_8,_512>>>; - -using CLayout_64x240 = Layout,Shape < _2,_2, Int<30>>>, - Stride,Stride<_64,_8,_512>>>; - -using CLayout_64x256 = Layout,Shape < _2,_2, _32>>, - Stride,Stride<_64,_8,_512>>>; +template +using CLayout_64xN = Layout,Shape < _2,_2,Int>>, + Stride,Stride<_64,_8, _512>>>; + +using CLayout_64x8 = CLayout_64xN< 8>; +using CLayout_64x16 = CLayout_64xN< 16>; +using CLayout_64x32 = CLayout_64xN< 32>; +using CLayout_64x64 = CLayout_64xN< 64>; +using CLayout_64x96 = CLayout_64xN< 96>; +using CLayout_64x128 = CLayout_64xN<128>; +using CLayout_64x192 = CLayout_64xN<192>; +using CLayout_64x256 = CLayout_64xN<256>; // Register source layout for 32-bit value types using ALayout_64x8 = Layout,Shape < _2, _2>>, Stride,Stride< _8,_256>>>; // Register source layout for 16-bit (sparse 32-bit) value types -using ALayout_64x16 = CLayout_64x16; +using ALayout_64x16 = Layout,Shape < _2,_2, _2>>, + Stride,Stride<_64,_8,_512>>>; // Register source layout for 8-bit (sparse 16-bit) value types using ALayout_64x32 = Layout,Shape < _4,_2, _2>>, @@ -549,7 
+509,6 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -579,7 +538,6 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -610,7 +568,6 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -640,7 +597,6 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -671,7 +627,6 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::Major tnspA, GMMA::Major tnspB, @@ -701,18 +656,16 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x48x16_F16F16F16_SS = SM90::GMMA::MMA_64x48x16_F16F16F16_SS; +using SM90_64x64x16_F16F16F16_SS = SM90::GMMA::MMA_64x64x16_F16F16F16_SS; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = half_t; @@ -722,30 +675,27 @@ struct MMA_Traits> using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_48,_16>; + using Shape_MNK = Shape<_64,_64,_16>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout< 48, 16>; - using CLayout = GMMA::CLayout_64x48; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; 
}; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x48x16_F16F16F16_RS = SM90::GMMA::MMA_64x48x16_F16F16F16_RS; +using SM90_64x64x16_F16F16F16_RS = SM90::GMMA::MMA_64x64x16_F16F16F16_RS; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = half_t; @@ -754,29 +704,27 @@ struct MMA_Traits> using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_48,_16>; + using Shape_MNK = Shape<_64,_64,_16>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout< 48, 16>; - using CLayout = GMMA::CLayout_64x48; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x64x16_F16F16F16_SS = SM90::GMMA::MMA_64x64x16_F16F16F16_SS; +using SM90_64x96x16_F16F16F16_SS = SM90::GMMA::MMA_64x96x16_F16F16F16_SS; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = half_t; @@ -786,28 +734,27 @@ struct MMA_Traits> using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_64,_16>; + using Shape_MNK = Shape<_64,_96,_16>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout< 64, 16>; - using CLayout = GMMA::CLayout_64x64; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; 
//////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x64x16_F16F16F16_RS = SM90::GMMA::MMA_64x64x16_F16F16F16_RS; +using SM90_64x96x16_F16F16F16_RS = SM90::GMMA::MMA_64x96x16_F16F16F16_RS; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = half_t; @@ -816,29 +763,27 @@ struct MMA_Traits> using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_64,_16>; + using Shape_MNK = Shape<_64,_96,_16>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout< 64, 16>; - using CLayout = GMMA::CLayout_64x64; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x80x16_F16F16F16_SS = SM90::GMMA::MMA_64x80x16_F16F16F16_SS; +using SM90_64x128x16_F16F16F16_SS = SM90::GMMA::MMA_64x128x16_F16F16F16_SS; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = half_t; @@ -848,30 +793,27 @@ struct MMA_Traits> using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_16>; + using Shape_MNK = Shape<_64,_128,_16>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout< 80, 16>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif 
//////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x80x16_F16F16F16_RS = SM90::GMMA::MMA_64x80x16_F16F16F16_RS; +using SM90_64x128x16_F16F16F16_RS = SM90::GMMA::MMA_64x128x16_F16F16F16_RS; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = half_t; @@ -880,29 +822,27 @@ struct MMA_Traits> using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_16>; + using Shape_MNK = Shape<_64,_128,_16>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout< 80, 16>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x96x16_F16F16F16_SS = SM90::GMMA::MMA_64x96x16_F16F16F16_SS; +using SM90_64x192x16_F16F16F16_SS = SM90::GMMA::MMA_64x192x16_F16F16F16_SS; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = half_t; @@ -912,28 +852,27 @@ struct MMA_Traits> using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_96,_16>; + using Shape_MNK = Shape<_64,_192,_16>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout< 96, 16>; - using CLayout = GMMA::CLayout_64x96; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; 
//////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x96x16_F16F16F16_RS = SM90::GMMA::MMA_64x96x16_F16F16F16_RS; +using SM90_64x192x16_F16F16F16_RS = SM90::GMMA::MMA_64x192x16_F16F16F16_RS; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = half_t; @@ -942,29 +881,27 @@ struct MMA_Traits> using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_96,_16>; + using Shape_MNK = Shape<_64,_192,_16>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout< 96, 16>; - using CLayout = GMMA::CLayout_64x96; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x112x16_F16F16F16_SS = SM90::GMMA::MMA_64x112x16_F16F16F16_SS; +using SM90_64x256x16_F16F16F16_SS = SM90::GMMA::MMA_64x256x16_F16F16F16_SS; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = half_t; @@ -974,30 +911,27 @@ struct MMA_Traits> using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_16>; + using Shape_MNK = Shape<_64,_256,_16>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout<112, 16>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif 
//////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x112x16_F16F16F16_RS = SM90::GMMA::MMA_64x112x16_F16F16F16_RS; +using SM90_64x256x16_F16F16F16_RS = SM90::GMMA::MMA_64x256x16_F16F16F16_RS; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = half_t; @@ -1006,210 +940,194 @@ struct MMA_Traits> using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_16>; + using Shape_MNK = Shape<_64,_256,_16>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout<112, 16>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x128x16_F16F16F16_SS = SM90::GMMA::MMA_64x128x16_F16F16F16_SS; +using SM90_64x8x16_F32F16F16_SS = SM90::GMMA::MMA_64x8x16_F32F16F16_SS; template -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; + using ValTypeD = float; using ValTypeA = half_t; using ValTypeB = half_t; - using ValTypeC = half_t; + using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_16>; + using Shape_MNK = Shape<_64,_8,_16>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout<128, 16>; - using CLayout = GMMA::CLayout_64x128; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut 
accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x128x16_F16F16F16_RS = SM90::GMMA::MMA_64x128x16_F16F16F16_RS; +using SM90_64x8x16_F32F16F16_RS = SM90::GMMA::MMA_64x8x16_F32F16F16_RS; template -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; + using ValTypeD = float; using ValTypeA = half_t; using ValTypeB = half_t; - using ValTypeC = half_t; + using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_16>; + using Shape_MNK = Shape<_64,_8,_16>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout<128, 16>; - using CLayout = GMMA::CLayout_64x128; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x144x16_F16F16F16_SS = SM90::GMMA::MMA_64x144x16_F16F16F16_SS; +using SM90_64x16x16_F32F16F16_SS = SM90::GMMA::MMA_64x16x16_F32F16F16_SS; template -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; + using ValTypeD = float; using ValTypeA = half_t; using ValTypeB = half_t; - using ValTypeC = half_t; + using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_16>; + using Shape_MNK = Shape<_64,_16,_16>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout<144, 16>; - using CLayout = GMMA::CLayout_64x144; + using BLayout = 
GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x144x16_F16F16F16_RS = SM90::GMMA::MMA_64x144x16_F16F16F16_RS; +using SM90_64x16x16_F32F16F16_RS = SM90::GMMA::MMA_64x16x16_F32F16F16_RS; template -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; + using ValTypeD = float; using ValTypeA = half_t; using ValTypeB = half_t; - using ValTypeC = half_t; + using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_16>; + using Shape_MNK = Shape<_64,_16,_16>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout<144, 16>; - using CLayout = GMMA::CLayout_64x144; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x160x16_F16F16F16_SS = SM90::GMMA::MMA_64x160x16_F16F16F16_SS; +using SM90_64x32x16_F32F16F16_SS = SM90::GMMA::MMA_64x32x16_F32F16F16_SS; template -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; + using ValTypeD = float; using ValTypeA = half_t; using ValTypeB = half_t; - using ValTypeC = half_t; + using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_16>; + using Shape_MNK = Shape<_64,_32,_16>; using ThrID = 
Layout<_128>; using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout<160, 16>; - using CLayout = GMMA::CLayout_64x160; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x160x16_F16F16F16_RS = SM90::GMMA::MMA_64x160x16_F16F16F16_RS; +using SM90_64x32x16_F32F16F16_RS = SM90::GMMA::MMA_64x32x16_F32F16F16_RS; template -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; + using ValTypeD = float; using ValTypeA = half_t; using ValTypeB = half_t; - using ValTypeC = half_t; + using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_16>; + using Shape_MNK = Shape<_64,_32,_16>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout<160, 16>; - using CLayout = GMMA::CLayout_64x160; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template < GMMA::Major tnspA, @@ -1217,335 +1135,251 @@ template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x176x16_F16F16F16_SS = SM90::GMMA::MMA_64x176x16_F16F16F16_SS; +using SM90_64x64x16_F32F16F16_SS = SM90::GMMA::MMA_64x64x16_F32F16F16_SS; template -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; + using ValTypeD = float; using ValTypeA = half_t; using ValTypeB = half_t; - using ValTypeC = half_t; + using ValTypeC = 
float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_16>; + using Shape_MNK = Shape<_64,_64,_16>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout<176, 16>; - using CLayout = GMMA::CLayout_64x176; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x176x16_F16F16F16_RS = SM90::GMMA::MMA_64x176x16_F16F16F16_RS; +using SM90_64x64x16_F32F16F16_RS = SM90::GMMA::MMA_64x64x16_F32F16F16_RS; template -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; + using ValTypeD = float; using ValTypeA = half_t; using ValTypeB = half_t; - using ValTypeC = half_t; + using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_16>; + using Shape_MNK = Shape<_64,_64,_16>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout<176, 16>; - using CLayout = GMMA::CLayout_64x176; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x192x16_F16F16F16_SS = SM90::GMMA::MMA_64x192x16_F16F16F16_SS; +using SM90_64x96x16_F32F16F16_SS = SM90::GMMA::MMA_64x96x16_F32F16F16_SS; template -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; + using ValTypeD = 
float; using ValTypeA = half_t; using ValTypeB = half_t; - using ValTypeC = half_t; + using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_192,_16>; + using Shape_MNK = Shape<_64,_96,_16>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout<192, 16>; - using CLayout = GMMA::CLayout_64x192; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x192x16_F16F16F16_RS = SM90::GMMA::MMA_64x192x16_F16F16F16_RS; +using SM90_64x96x16_F32F16F16_RS = SM90::GMMA::MMA_64x96x16_F32F16F16_RS; template -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; + using ValTypeD = float; using ValTypeA = half_t; using ValTypeB = half_t; - using ValTypeC = half_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout<192, 16>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x208x16_F16F16F16_SS = SM90::GMMA::MMA_64x208x16_F16F16F16_SS; - -template -struct MMA_Traits> -{ - using ValTypeD = half_t; - using ValTypeA = half_t; - using ValTypeB = half_t; - using ValTypeC = half_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = 
GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout<208, 16>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x208x16_F16F16F16_RS = SM90::GMMA::MMA_64x208x16_F16F16F16_RS; - -template -struct MMA_Traits> -{ - using ValTypeD = half_t; - using ValTypeA = half_t; - using ValTypeB = half_t; - using ValTypeC = half_t; + using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_208,_16>; + using Shape_MNK = Shape<_64,_96,_16>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout<208, 16>; - using CLayout = GMMA::CLayout_64x208; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x224x16_F16F16F16_SS = SM90::GMMA::MMA_64x224x16_F16F16F16_SS; +using SM90_64x128x16_F32F16F16_SS = SM90::GMMA::MMA_64x128x16_F32F16F16_SS; template -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; + using ValTypeD = float; using ValTypeA = half_t; using ValTypeB = half_t; - using ValTypeC = half_t; + using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = 
Shape<_64,_224,_16>; + using Shape_MNK = Shape<_64,_128,_16>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout<224, 16>; - using CLayout = GMMA::CLayout_64x224; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x224x16_F16F16F16_RS = SM90::GMMA::MMA_64x224x16_F16F16F16_RS; +using SM90_64x128x16_F32F16F16_RS = SM90::GMMA::MMA_64x128x16_F32F16F16_RS; template -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; + using ValTypeD = float; using ValTypeA = half_t; using ValTypeB = half_t; - using ValTypeC = half_t; + using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_224,_16>; + using Shape_MNK = Shape<_64,_128,_16>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout<224, 16>; - using CLayout = GMMA::CLayout_64x224; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x240x16_F16F16F16_SS = SM90::GMMA::MMA_64x240x16_F16F16F16_SS; +using SM90_64x192x16_F32F16F16_SS = SM90::GMMA::MMA_64x192x16_F32F16F16_SS; template -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; + using ValTypeD = float; using ValTypeA = half_t; using 
ValTypeB = half_t; - using ValTypeC = half_t; + using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_240,_16>; + using Shape_MNK = Shape<_64,_192,_16>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout<240, 16>; - using CLayout = GMMA::CLayout_64x240; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x240x16_F16F16F16_RS = SM90::GMMA::MMA_64x240x16_F16F16F16_RS; +using SM90_64x192x16_F32F16F16_RS = SM90::GMMA::MMA_64x192x16_F32F16F16_RS; template -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; + using ValTypeD = float; using ValTypeA = half_t; using ValTypeB = half_t; - using ValTypeC = half_t; + using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_240,_16>; + using Shape_MNK = Shape<_64,_192,_16>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout<240, 16>; - using CLayout = GMMA::CLayout_64x240; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x256x16_F16F16F16_SS = SM90::GMMA::MMA_64x256x16_F16F16F16_SS; +using SM90_64x256x16_F32F16F16_SS = SM90::GMMA::MMA_64x256x16_F32F16F16_SS; template -struct MMA_Traits> 
+struct MMA_Traits> { - using ValTypeD = half_t; + using ValTypeD = float; using ValTypeA = half_t; using ValTypeB = half_t; - using ValTypeC = half_t; + using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; @@ -1561,22 +1395,21 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x256x16_F16F16F16_RS = SM90::GMMA::MMA_64x256x16_F16F16F16_RS; +using SM90_64x256x16_F32F16F16_RS = SM90::GMMA::MMA_64x256x16_F32F16F16_RS; template -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; + using ValTypeD = float; using ValTypeA = half_t; using ValTypeB = half_t; - using ValTypeC = half_t; + using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; @@ -1591,21 +1424,20 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x8x16_F32F16F16_SS = SM90::GMMA::MMA_64x8x16_F32F16F16_SS; +using SM90_64x8x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x8x16_F32BF16BF16_SS; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; @@ -1622,21 +1454,20 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x8x16_F32F16F16_RS = SM90::GMMA::MMA_64x8x16_F32F16F16_RS; +using SM90_64x8x16_F32BF16BF16_RS = 
SM90::GMMA::MMA_64x8x16_F32BF16BF16_RS; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; @@ -1652,21 +1483,20 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x16x16_F32F16F16_SS = SM90::GMMA::MMA_64x16x16_F32F16F16_SS; +using SM90_64x16x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x16x16_F32BF16BF16_SS; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; @@ -1683,21 +1513,20 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x16x16_F32F16F16_RS = SM90::GMMA::MMA_64x16x16_F32F16F16_RS; +using SM90_64x16x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x16x16_F32BF16BF16_RS; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; @@ -1713,21 +1542,20 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x32x16_F32F16F16_SS = 
SM90::GMMA::MMA_64x32x16_F32F16F16_SS; +using SM90_64x32x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x32x16_F32BF16BF16_SS; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; @@ -1744,21 +1572,20 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x32x16_F32F16F16_RS = SM90::GMMA::MMA_64x32x16_F32F16F16_RS; +using SM90_64x32x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x32x16_F32BF16BF16_RS; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; @@ -1774,893 +1601,760 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x40x16_F32F16F16_SS = SM90::GMMA::MMA_64x40x16_F32F16F16_SS; +using SM90_64x64x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x64x16_F32BF16BF16_SS; template -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; - using ValTypeC = float; + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = 
Shape<_64,Int<40>,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout< 40, 16>; - using CLayout = GMMA::CLayout_64x40; + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x48x16_F32F16F16_SS = SM90::GMMA::MMA_64x48x16_F32F16F16_SS; +using SM90_64x64x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x64x16_F32BF16BF16_RS; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; using ValTypeC = float; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_48,_16>; + using Shape_MNK = Shape<_64,_64,_16>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout< 48, 16>; - using CLayout = GMMA::CLayout_64x48; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x48x16_F32F16F16_RS = SM90::GMMA::MMA_64x48x16_F32F16F16_RS; 
+using SM90_64x96x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x96x16_F32BF16BF16_SS; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; using ValTypeC = float; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_48,_16>; + using Shape_MNK = Shape<_64,_96,_16>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout< 48, 16>; - using CLayout = GMMA::CLayout_64x48; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x64x16_F32F16F16_SS = SM90::GMMA::MMA_64x64x16_F32F16F16_SS; +using SM90_64x96x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x96x16_F32BF16BF16_RS; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; using ValTypeC = float; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_64,_16>; + using Shape_MNK = Shape<_64,_96,_16>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout< 64, 16>; - using CLayout = GMMA::CLayout_64x64; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::Major tnspA, 
GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x64x16_F32F16F16_RS = SM90::GMMA::MMA_64x64x16_F32F16F16_RS; +using SM90_64x128x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x128x16_F32BF16BF16_SS; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout< 64, 16>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x80x16_F32F16F16_SS = SM90::GMMA::MMA_64x80x16_F32F16F16_SS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_16>; + using Shape_MNK = Shape<_64,_128,_16>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout< 80, 16>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB 
= GMMA::ScaleIn::One > -using SM90_64x80x16_F32F16F16_RS = SM90::GMMA::MMA_64x80x16_F32F16F16_RS; +using SM90_64x128x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x128x16_F32BF16BF16_RS; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_16>; + using Shape_MNK = Shape<_64,_128,_16>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout< 80, 16>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x96x16_F32F16F16_SS = SM90::GMMA::MMA_64x96x16_F32F16F16_SS; +using SM90_64x192x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x192x16_F32BF16BF16_SS; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_96,_16>; + using Shape_MNK = Shape<_64,_192,_16>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout< 96, 16>; - using CLayout = GMMA::CLayout_64x96; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::Major tnspA, GMMA::Major tnspB, 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x96x16_F32F16F16_RS = SM90::GMMA::MMA_64x96x16_F32F16F16_RS; +using SM90_64x192x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x192x16_F32BF16BF16_RS; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_96,_16>; + using Shape_MNK = Shape<_64,_192,_16>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout< 96, 16>; - using CLayout = GMMA::CLayout_64x96; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x112x16_F32F16F16_SS = SM90::GMMA::MMA_64x112x16_F32F16F16_SS; +using SM90_64x256x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x256x16_F32BF16BF16_SS; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_16>; + using Shape_MNK = Shape<_64,_256,_16>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout<112, 16>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif 
//////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x112x16_F32F16F16_RS = SM90::GMMA::MMA_64x112x16_F32F16F16_RS; +using SM90_64x256x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x256x16_F32BF16BF16_RS; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_16>; + using Shape_MNK = Shape<_64,_256,_16>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout<112, 16>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// - template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x128x16_F32F16F16_SS = SM90::GMMA::MMA_64x128x16_F32F16F16_SS; +using SM90_64x8x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x8x8_F32TF32TF32_SS_TN; -template -struct MMA_Traits> +template +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; using ValTypeC = float; - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_16>; + using Shape_MNK = Shape<_64,_8,_8>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - 
using BLayout = GMMA::ABLayout<128, 16>; - using CLayout = GMMA::CLayout_64x128; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 8, 8>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x128x16_F32F16F16_RS = SM90::GMMA::MMA_64x128x16_F32F16F16_RS; +using SM90_64x8x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x8x8_F32TF32TF32_RS_TN; -template -struct MMA_Traits> +template +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; using ValTypeC = float; - using FrgTypeB = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_16>; + using Shape_MNK = Shape<_64,_8,_8>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout<128, 16>; - using CLayout = GMMA::CLayout_64x128; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 8, 8>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x144x16_F32F16F16_SS = SM90::GMMA::MMA_64x144x16_F32F16F16_SS; +using SM90_64x16x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x16x8_F32TF32TF32_SS_TN; -template -struct MMA_Traits> +template +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; + using ValTypeA = tfloat32_t; + using ValTypeB = 
tfloat32_t; using ValTypeC = float; - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_16>; + using Shape_MNK = Shape<_64,_16,_8>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout<144, 16>; - using CLayout = GMMA::CLayout_64x144; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 16, 8>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x144x16_F32F16F16_RS = SM90::GMMA::MMA_64x144x16_F32F16F16_RS; +using SM90_64x16x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x16x8_F32TF32TF32_RS_TN; -template -struct MMA_Traits> +template +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; using ValTypeC = float; - using FrgTypeB = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_16>; + using Shape_MNK = Shape<_64,_16,_8>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout<144, 16>; - using CLayout = GMMA::CLayout_64x144; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 16, 8>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x160x16_F32F16F16_SS = SM90::GMMA::MMA_64x160x16_F32F16F16_SS; +using SM90_64x32x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x32x8_F32TF32TF32_SS_TN; -template -struct MMA_Traits> +template +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; using ValTypeC = float; - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_16>; + using Shape_MNK = Shape<_64,_32,_8>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout<160, 16>; - using CLayout = GMMA::CLayout_64x160; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 32, 8>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x160x16_F32F16F16_RS = SM90::GMMA::MMA_64x160x16_F32F16F16_RS; +using SM90_64x32x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x32x8_F32TF32TF32_RS_TN; -template -struct MMA_Traits> +template +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; using ValTypeC = float; - using FrgTypeB = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_16>; + using Shape_MNK = Shape<_64,_32,_8>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout<160, 16>; - using CLayout = GMMA::CLayout_64x160; + 
using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 32, 8>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x176x16_F32F16F16_SS = SM90::GMMA::MMA_64x176x16_F32F16F16_SS; +using SM90_64x64x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x64x8_F32TF32TF32_SS_TN; -template -struct MMA_Traits> +template +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; using ValTypeC = float; - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_16>; + using Shape_MNK = Shape<_64,_64,_8>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout<176, 16>; - using CLayout = GMMA::CLayout_64x176; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 64, 8>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x176x16_F32F16F16_RS = SM90::GMMA::MMA_64x176x16_F32F16F16_RS; +using SM90_64x64x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x64x8_F32TF32TF32_RS_TN; -template -struct MMA_Traits> +template +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = half_t; - using 
ValTypeB = half_t; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; using ValTypeC = float; - using FrgTypeB = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_16>; + using Shape_MNK = Shape<_64,_64,_8>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout<176, 16>; - using CLayout = GMMA::CLayout_64x176; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 64, 8>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// - template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x192x16_F32F16F16_SS = SM90::GMMA::MMA_64x192x16_F32F16F16_SS; +using SM90_64x96x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x96x8_F32TF32TF32_SS_TN; -template -struct MMA_Traits> +template +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; using ValTypeC = float; - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_192,_16>; + using Shape_MNK = Shape<_64,_96,_8>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout<192, 16>; - using CLayout = GMMA::CLayout_64x192; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 96, 8>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One > -using SM90_64x192x16_F32F16F16_RS = SM90::GMMA::MMA_64x192x16_F32F16F16_RS; +using SM90_64x96x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x96x8_F32TF32TF32_RS_TN; -template -struct MMA_Traits> +template +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; using ValTypeC = float; - using FrgTypeB = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_192,_16>; + using Shape_MNK = Shape<_64,_96,_8>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout<192, 16>; - using CLayout = GMMA::CLayout_64x192; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 96, 8>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x208x16_F32F16F16_SS = SM90::GMMA::MMA_64x208x16_F32F16F16_SS; +using SM90_64x128x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x128x8_F32TF32TF32_SS_TN; -template -struct MMA_Traits> +template +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; using ValTypeC = float; - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_208,_16>; + using Shape_MNK = Shape<_64,_128,_8>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout<208, 16>; - using CLayout = GMMA::CLayout_64x208; + using ALayout = GMMA::ABLayout< 64, 8>; + using 
BLayout = GMMA::ABLayout<128, 8>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x208x16_F32F16F16_RS = SM90::GMMA::MMA_64x208x16_F32F16F16_RS; +using SM90_64x128x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x128x8_F32TF32TF32_RS_TN; -template -struct MMA_Traits> +template +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; using ValTypeC = float; - using FrgTypeB = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_208,_16>; + using Shape_MNK = Shape<_64,_128,_8>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout<208, 16>; - using CLayout = GMMA::CLayout_64x208; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<128, 8>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x224x16_F32F16F16_SS = SM90::GMMA::MMA_64x224x16_F32F16F16_SS; +using SM90_64x192x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x192x8_F32TF32TF32_SS_TN; -template -struct MMA_Traits> +template +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; using ValTypeC = float; - using 
FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_224,_16>; + using Shape_MNK = Shape<_64,_192,_8>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout<224, 16>; - using CLayout = GMMA::CLayout_64x224; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<192, 8>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x224x16_F32F16F16_RS = SM90::GMMA::MMA_64x224x16_F32F16F16_RS; +using SM90_64x192x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x192x8_F32TF32TF32_RS_TN; -template -struct MMA_Traits> +template +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; using ValTypeC = float; - using FrgTypeB = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_224,_16>; + using Shape_MNK = Shape<_64,_192,_8>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout<224, 16>; - using CLayout = GMMA::CLayout_64x224; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<192, 8>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One > -using SM90_64x240x16_F32F16F16_SS = SM90::GMMA::MMA_64x240x16_F32F16F16_SS; +using SM90_64x256x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x256x8_F32TF32TF32_SS_TN; -template -struct MMA_Traits> +template +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; using ValTypeC = float; - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_240,_16>; + using Shape_MNK = Shape<_64,_256,_8>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout<240, 16>; - using CLayout = GMMA::CLayout_64x240; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<256, 8>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < - GMMA::Major tnspA, - GMMA::Major tnspB, GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x240x16_F32F16F16_RS = SM90::GMMA::MMA_64x240x16_F32F16F16_RS; +using SM90_64x256x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x256x8_F32TF32TF32_RS_TN; -template -struct MMA_Traits> +template +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; using ValTypeC = float; - using FrgTypeB = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_240,_16>; + using Shape_MNK = Shape<_64,_256,_8>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout<240, 16>; - using CLayout = GMMA::CLayout_64x240; + using ALayout = GMMA::ALayout_64x8; + 
using BLayout = GMMA::ABLayout<256, 8>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x256x16_F32F16F16_SS = SM90::GMMA::MMA_64x256x16_F32F16F16_SS; +using SM90_64x8x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x8x32_S32S8S8_SS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_256,_16>; + using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout<256, 16>; - using CLayout = GMMA::CLayout_64x256; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -2668,29 +2362,24 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x256x16_F32F16F16_RS = SM90::GMMA::MMA_64x256x16_F32F16F16_RS; +using SM90_64x8x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32S8S8_SS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = half_t; - using ValTypeB = half_t; - using ValTypeC = 
float; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using FrgTypeB = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_256,_16>; + using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout<256, 16>; - using CLayout = GMMA::CLayout_64x256; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -2698,9884 +2387,24 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x8x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x8x16_F32BF16BF16_SS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout< 8, 16>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x8x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x8x16_F32BF16BF16_RS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - 
using Shape_MNK = Shape<_64,_8,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout< 8, 16>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x16x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x16x16_F32BF16BF16_SS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout< 16, 16>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x16x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x16x16_F32BF16BF16_RS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout< 16, 16>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x32x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x32x16_F32BF16BF16_SS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout< 32, 16>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x32x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x32x16_F32BF16BF16_RS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout< 32, 16>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x40x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x40x16_F32BF16BF16_SS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using 
FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,Int<40>,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout< 40, 16>; - using CLayout = GMMA::CLayout_64x40; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x48x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x48x16_F32BF16BF16_SS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout< 48, 16>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x48x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x48x16_F32BF16BF16_RS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout< 48, 16>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x64x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x64x16_F32BF16BF16_SS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout< 64, 16>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x64x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x64x16_F32BF16BF16_RS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout< 64, 16>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x80x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x80x16_F32BF16BF16_SS; - -template -struct MMA_Traits> -{ 
- using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout< 80, 16>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x80x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x80x16_F32BF16BF16_RS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout< 80, 16>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x96x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x96x16_F32BF16BF16_SS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout< 96, 16>; - 
using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x96x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x96x16_F32BF16BF16_RS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout< 96, 16>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x112x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x112x16_F32BF16BF16_SS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout<112, 16>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, 
- GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x112x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x112x16_F32BF16BF16_RS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout<112, 16>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x128x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x128x16_F32BF16BF16_SS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout<128, 16>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x128x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x128x16_F32BF16BF16_RS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_16>; - using ThrID = Layout<_128>; - using 
ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout<128, 16>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x144x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x144x16_F32BF16BF16_SS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout<144, 16>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x144x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x144x16_F32BF16BF16_RS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout<144, 16>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x160x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x160x16_F32BF16BF16_SS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout<160, 16>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x160x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x160x16_F32BF16BF16_RS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout<160, 16>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x176x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x176x16_F32BF16BF16_SS; - -template 
-struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout<176, 16>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x176x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x176x16_F32BF16BF16_RS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout<176, 16>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x192x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x192x16_F32BF16BF16_SS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using 
BLayout = GMMA::ABLayout<192, 16>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x192x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x192x16_F32BF16BF16_RS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout<192, 16>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x208x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x208x16_F32BF16BF16_SS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout<208, 16>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x208x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x208x16_F32BF16BF16_RS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout<208, 16>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x224x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x224x16_F32BF16BF16_SS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout<224, 16>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x224x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x224x16_F32BF16BF16_RS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = 
bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout<224, 16>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x240x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x240x16_F32BF16BF16_SS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout<240, 16>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x240x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x240x16_F32BF16BF16_RS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout<240, 16>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut 
accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x256x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x256x16_F32BF16BF16_SS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using BLayout = GMMA::ABLayout<256, 16>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::Major tnspA, - GMMA::Major tnspB, - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x256x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x256x16_F32BF16BF16_RS; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = bfloat16_t; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using BLayout = GMMA::ABLayout<256, 16>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x8x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x8x8_F32TF32TF32_SS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using 
ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_8>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 8>; - using BLayout = GMMA::ABLayout< 8, 8>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x8x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x8x8_F32TF32TF32_RS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_8>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x8; - using BLayout = GMMA::ABLayout< 8, 8>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x16x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x16x8_F32TF32TF32_SS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_8>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 8>; - using BLayout = GMMA::ABLayout< 16, 8>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x16x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x16x8_F32TF32TF32_RS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_8>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x8; - using BLayout = GMMA::ABLayout< 16, 8>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x32x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x32x8_F32TF32TF32_SS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_8>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 8>; - using BLayout = GMMA::ABLayout< 32, 8>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x32x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x32x8_F32TF32TF32_RS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK 
= Shape<_64,_32,_8>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x8; - using BLayout = GMMA::ABLayout< 32, 8>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x48x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x48x8_F32TF32TF32_SS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_8>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 8>; - using BLayout = GMMA::ABLayout< 48, 8>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x48x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x48x8_F32TF32TF32_RS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_8>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x8; - using BLayout = GMMA::ABLayout< 48, 8>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x64x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x64x8_F32TF32TF32_SS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_8>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 8>; - using BLayout = GMMA::ABLayout< 64, 8>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x64x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x64x8_F32TF32TF32_RS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_8>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x8; - using BLayout = GMMA::ABLayout< 64, 8>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x80x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x80x8_F32TF32TF32_SS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_8>; - using ThrID = Layout<_128>; - 
using ALayout = GMMA::ABLayout< 64, 8>; - using BLayout = GMMA::ABLayout< 80, 8>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x80x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x80x8_F32TF32TF32_RS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_8>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x8; - using BLayout = GMMA::ABLayout< 80, 8>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x96x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x96x8_F32TF32TF32_SS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_8>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 8>; - using BLayout = GMMA::ABLayout< 96, 8>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x96x8_F32TF32TF32_RS_TN = 
SM90::GMMA::MMA_64x96x8_F32TF32TF32_RS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_8>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x8; - using BLayout = GMMA::ABLayout< 96, 8>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x112x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x112x8_F32TF32TF32_SS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_8>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 8>; - using BLayout = GMMA::ABLayout<112, 8>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x112x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x112x8_F32TF32TF32_RS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_8>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x8; - using BLayout = GMMA::ABLayout<112, 8>; 
- using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x128x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x128x8_F32TF32TF32_SS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_8>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 8>; - using BLayout = GMMA::ABLayout<128, 8>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x128x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x128x8_F32TF32TF32_RS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_8>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x8; - using BLayout = GMMA::ABLayout<128, 8>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x144x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x144x8_F32TF32TF32_SS_TN; - -template -struct MMA_Traits> -{ - 
using ValTypeD = float; - using ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_8>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 8>; - using BLayout = GMMA::ABLayout<144, 8>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x144x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x144x8_F32TF32TF32_RS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_8>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x8; - using BLayout = GMMA::ABLayout<144, 8>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x160x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x160x8_F32TF32TF32_SS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_8>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 8>; - using BLayout = GMMA::ABLayout<160, 8>; - using CLayout = 
GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x160x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x160x8_F32TF32TF32_RS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_8>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x8; - using BLayout = GMMA::ABLayout<160, 8>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x176x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x176x8_F32TF32TF32_SS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_8>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 8>; - using BLayout = GMMA::ABLayout<176, 8>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using 
SM90_64x176x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x176x8_F32TF32TF32_RS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_8>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x8; - using BLayout = GMMA::ABLayout<176, 8>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x192x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x192x8_F32TF32TF32_SS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_8>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 8>; - using BLayout = GMMA::ABLayout<192, 8>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x192x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x192x8_F32TF32TF32_RS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_8>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x8; - using BLayout = GMMA::ABLayout<192, 8>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut 
accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x208x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x208x8_F32TF32TF32_SS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_8>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 8>; - using BLayout = GMMA::ABLayout<208, 8>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x208x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x208x8_F32TF32TF32_RS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_8>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x8; - using BLayout = GMMA::ABLayout<208, 8>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x224x8_F32TF32TF32_SS_TN = 
SM90::GMMA::MMA_64x224x8_F32TF32TF32_SS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_8>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 8>; - using BLayout = GMMA::ABLayout<224, 8>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x224x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x224x8_F32TF32TF32_RS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_8>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x8; - using BLayout = GMMA::ABLayout<224, 8>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x240x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x240x8_F32TF32TF32_SS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_8>; - using ThrID = Layout<_128>; - using ALayout = 
GMMA::ABLayout< 64, 8>; - using BLayout = GMMA::ABLayout<240, 8>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x240x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x240x8_F32TF32TF32_RS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_8>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x8; - using BLayout = GMMA::ABLayout<240, 8>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x256x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x256x8_F32TF32TF32_SS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_8>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 8>; - using BLayout = GMMA::ABLayout<256, 8>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x256x8_F32TF32TF32_RS_TN = 
SM90::GMMA::MMA_64x256x8_F32TF32TF32_RS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = tfloat32_t; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_8>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x8; - using BLayout = GMMA::ABLayout<256, 8>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x8x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x8x32_S32S8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x8x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32S8S8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x16x32_S32S8S8_SS_TN = 
SM90::GMMA::MMA_64x16x32_S32S8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x16x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32S8S8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x32x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x32x32_S32S8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using 
SM90_64x32x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32S8S8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x48x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x48x32_S32S8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x48x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32S8S8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x64x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x64x32_S32S8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 64, 32>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x64x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32S8S8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 64, 32>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x80x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x80x32_S32S8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 80, 32>; - 
using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x80x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32S8S8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x96x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x96x32_S32S8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 96, 32>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x96x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32S8S8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_32>; - using ThrID = 
Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 96, 32>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x112x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x112x32_S32S8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x112x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32S8S8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x128x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x128x32_S32S8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = 
int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x128x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32S8S8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x144x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x144x32_S32S8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x144x32_S32S8S8_SS_TN_SATURATE = 
SM90::GMMA::MMA_64x144x32_S32S8S8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x160x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x160x32_S32S8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x160x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32S8S8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; 
-#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x176x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x176x32_S32S8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x176x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32S8S8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x192x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x192x32_S32S8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 
64, 32>; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x192x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32S8S8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x208x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x208x32_S32S8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x208x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32S8S8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = 
GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x224x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x224x32_S32S8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x224x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32S8S8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x240x32_S32S8S8_SS_TN = 
SM90::GMMA::MMA_64x240x32_S32S8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x240x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32S8S8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x256x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x256x32_S32S8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<256, 32>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x256x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32S8S8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<256, 32>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x8x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x8x32_S32S8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x8x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32S8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x16x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x16x32_S32S8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x16x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32S8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x32x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x32x32_S32S8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - 
-using SM90_64x32x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32S8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x48x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x48x32_S32S8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x48x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32S8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x64x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x64x32_S32S8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 64, 32>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x64x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32S8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 64, 32>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x80x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x80x32_S32S8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x80x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32S8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x96x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x96x32_S32S8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 96, 32>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x96x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32S8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 96, 32>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x112x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x112x32_S32S8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x112x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32S8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x128x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x128x32_S32S8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x128x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32S8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x144x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x144x32_S32S8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x144x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32S8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; - - 
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x160x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x160x32_S32S8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x160x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32S8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x176x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x176x32_S32S8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using 
BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x176x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32S8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x192x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x192x32_S32S8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x192x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32S8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = 
GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x208x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x208x32_S32S8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x208x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32S8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x224x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x224x32_S32S8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_32>; - using ThrID = 
Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x224x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32S8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x240x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x240x32_S32S8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x240x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32S8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; 
- - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x256x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x256x32_S32S8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<256, 32>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x256x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32S8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<256, 32>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x8x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x8x32_S32S8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = 
Shape<_64,_8,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x8x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32S8U8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x16x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x16x32_S32S8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x16x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32S8U8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - 
- using Shape_MNK = Shape<_64,_16,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x32x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x32x32_S32S8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x32x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32S8U8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x48x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x48x32_S32S8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using 
FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x48x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32S8U8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x64x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x64x32_S32S8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 64, 32>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x64x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32S8U8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = 
int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 64, 32>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x80x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x80x32_S32S8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x80x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32S8U8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - 
-using SM90_64x96x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x96x32_S32S8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 96, 32>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x96x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32S8U8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 96, 32>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x112x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x112x32_S32S8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x112x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32S8U8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x128x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x128x32_S32S8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x128x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32S8U8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<128, 32>; 
- using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x144x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x144x32_S32S8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x144x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32S8U8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x160x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x160x32_S32S8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; 
- using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x160x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32S8U8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x176x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x176x32_S32S8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x176x32_S32S8U8_SS_TN_SATURATE = 
SM90::GMMA::MMA_64x176x32_S32S8U8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x192x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x192x32_S32S8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x192x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32S8U8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - 
-#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x208x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x208x32_S32S8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x208x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32S8U8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x224x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x224x32_S32S8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<224, 32>; - using 
CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x224x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32S8U8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x240x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x240x32_S32S8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x240x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32S8U8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = 
GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x256x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x256x32_S32S8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<256, 32>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x256x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32S8U8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<256, 32>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x8x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x8x32_S32S8U8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = 
int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x8x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32S8U8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x16x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x16x32_S32S8U8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x16x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32S8U8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_32>; - using 
ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x32x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x32x32_S32S8U8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x32x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32S8U8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x48x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x48x32_S32S8U8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_32>; - using ThrID = Layout<_128>; - using ALayout = 
GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x48x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32S8U8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x64x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x64x32_S32S8U8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 64, 32>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x64x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32S8U8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - 
using BLayout = GMMA::ABLayout< 64, 32>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x80x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x80x32_S32S8U8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x80x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32S8U8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x96x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x96x32_S32S8U8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_32>; - using ThrID = Layout<_128>; - using ALayout = 
GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 96, 32>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x96x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32S8U8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 96, 32>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x112x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x112x32_S32S8U8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x112x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32S8U8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_32>; - using ThrID = 
Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x128x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x128x32_S32S8U8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x128x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32S8U8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x144x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x144x32_S32S8U8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_32>; - using ThrID = Layout<_128>; - using ALayout = 
GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x144x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32S8U8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x160x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x160x32_S32S8U8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x160x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32S8U8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = 
GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x176x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x176x32_S32S8U8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x176x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32S8U8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x192x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x192x32_S32S8U8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = 
int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x192x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32S8U8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x208x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x208x32_S32S8U8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x208x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32S8U8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using 
ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x224x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x224x32_S32S8U8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x224x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32S8U8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x240x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x240x32_S32S8U8_RS_TN; - -template <> 
-struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x240x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32S8U8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x256x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x256x32_S32S8U8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<256, 32>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x256x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32S8U8_RS_TN_SATURATE; - -template <> -struct 
MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = int8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<256, 32>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x8x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x8x32_S32U8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x8x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32U8S8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x16x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x16x32_S32U8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = 
int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x16x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32U8S8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x32x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x32x32_S32U8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x32x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32U8S8_SS_TN_SATURATE; - -template <> -struct 
MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x48x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x48x32_S32U8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x48x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32U8S8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x64x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x64x32_S32U8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 64, 32>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x64x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32U8S8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 64, 32>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x80x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x80x32_S32U8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = 
GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x80x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32U8S8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x96x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x96x32_S32U8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 96, 32>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x96x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32U8S8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_32>; - using ThrID = Layout<_128>; - using 
ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 96, 32>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x112x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x112x32_S32U8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x112x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32U8S8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x128x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x128x32_S32U8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using 
FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x128x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32U8S8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x144x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x144x32_S32U8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x144x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32U8S8_SS_TN_SATURATE; - 
-template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x160x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x160x32_S32U8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x160x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32U8S8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x176x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x176x32_S32U8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x176x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32U8S8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x192x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x192x32_S32U8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; 
- using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x192x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32U8S8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x208x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x208x32_S32U8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x208x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32U8S8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; 
- using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x224x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x224x32_S32U8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x224x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32U8S8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x240x32_S32U8S8_SS_TN = 
SM90::GMMA::MMA_64x240x32_S32U8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x240x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32U8S8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x256x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x256x32_S32U8S8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<256, 32>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x256x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32U8S8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<256, 32>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x8x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x8x32_S32U8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x8x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32U8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x16x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x16x32_S32U8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x16x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32U8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x32x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x32x32_S32U8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - 
-using SM90_64x32x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32U8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x48x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x48x32_S32U8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x48x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32U8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x64x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x64x32_S32U8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 64, 32>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x64x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32U8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 64, 32>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x80x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x80x32_S32U8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x80x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32U8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x96x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x96x32_S32U8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 96, 32>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x96x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32U8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 96, 32>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x112x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x112x32_S32U8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x112x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32U8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x128x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x128x32_S32U8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x128x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32U8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x144x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x144x32_S32U8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x144x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32U8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; 
- - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x160x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x160x32_S32U8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x160x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32U8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x176x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x176x32_S32U8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; 
- using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x176x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32U8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x192x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x192x32_S32U8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x192x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32U8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using 
BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x208x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x208x32_S32U8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x208x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32U8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x224x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x224x32_S32U8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_32>; - 
using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x224x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32U8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x240x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x240x32_S32U8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x240x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32U8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using 
ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x256x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x256x32_S32U8S8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<256, 32>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x256x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32U8S8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<256, 32>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x8x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x8x32_S32U8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = 
GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x8x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32U8U8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x16x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x16x32_S32U8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x16x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32U8U8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = 
GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x32x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x32x32_S32U8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x32x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32U8U8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x48x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x48x32_S32U8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using 
ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x48x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32U8U8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x64x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x64x32_S32U8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 64, 32>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x64x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32U8U8_SS_TN_SATURATE; - 
-template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 64, 32>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x80x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x80x32_S32U8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x80x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32U8U8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x96x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x96x32_S32U8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 96, 32>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x96x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32U8U8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 96, 32>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x112x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x112x32_S32U8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = 
GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x112x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32U8U8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x128x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x128x32_S32U8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x128x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32U8U8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_32>; - using ThrID = 
Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x144x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x144x32_S32U8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x144x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32U8U8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x160x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x160x32_S32U8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = 
uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x160x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32U8U8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x176x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x176x32_S32U8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x176x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32U8U8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x192x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x192x32_S32U8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x192x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32U8U8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x208x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x208x32_S32U8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x208x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32U8U8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x224x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x224x32_S32U8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_32>; - using ThrID = 
Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x224x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32U8U8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x240x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x240x32_S32U8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x240x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32U8U8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = 
int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x256x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x256x32_S32U8U8_SS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<256, 32>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x256x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32U8U8_SS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<256, 32>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x8x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x8x32_S32U8U8_RS_TN; - -template <> -struct 
MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x8x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32U8U8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x16x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x16x32_S32U8U8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x16x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32U8U8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = 
uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x32x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x32x32_S32U8U8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x32x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32U8U8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x48x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x48x32_S32U8U8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using 
FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x48x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32U8U8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x64x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x64x32_S32U8U8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 64, 32>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x64x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32U8U8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = 
GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 64, 32>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x80x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x80x32_S32U8U8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x80x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32U8U8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x96x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x96x32_S32U8U8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - 
using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 96, 32>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x96x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32U8U8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 96, 32>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x112x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x112x32_S32U8U8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x112x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32U8U8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - 
using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x128x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x128x32_S32U8U8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x128x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32U8U8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x144x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x144x32_S32U8U8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - 
using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x144x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32U8U8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x160x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x160x32_S32U8U8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x160x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32U8U8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using 
ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x176x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x176x32_S32U8U8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x176x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32U8U8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x192x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x192x32_S32U8U8_RS_TN; - -template 
<> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x192x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32U8U8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x208x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x208x32_S32U8U8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x208x32_S32U8U8_RS_TN_SATURATE = 
SM90::GMMA::MMA_64x208x32_S32U8U8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x224x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x224x32_S32U8U8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x224x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32U8U8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - 
-#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x240x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x240x32_S32U8U8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - - -using SM90_64x240x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32U8U8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x256x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x256x32_S32U8U8_RS_TN; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<256, 32>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - - - -using SM90_64x256x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32U8U8_RS_TN_SATURATE; - -template <> -struct MMA_Traits -{ - using ValTypeD = int32_t; - using ValTypeA = uint8_t; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<256, 32>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x8x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x8x32_F16E4M3E4M3_SS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x8x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x8x32_F16E4M3E4M3_RS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_32>; - using ThrID = Layout<_128>; - using ALayout = 
GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x8x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x8x32_F32E4M3E4M3_SS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x8x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x8x32_F32E4M3E4M3_RS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x16x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x16x32_F16E4M3E4M3_SS_TN; - -template -struct MMA_Traits> -{ 
- using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x16x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x16x32_F16E4M3E4M3_RS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x16x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x16x32_F32E4M3E4M3_SS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x16x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x16x32_F32E4M3E4M3_RS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x32x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x32x32_F16E4M3E4M3_SS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x32x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x32x32_F16E4M3E4M3_RS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; - - using FrgTypeB = 
GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x32x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x32x32_F32E4M3E4M3_SS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x32x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x32x32_F32E4M3E4M3_RS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x48x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x48x32_F16E4M3E4M3_SS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x48x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x48x32_F16E4M3E4M3_RS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x48x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x48x32_F32E4M3E4M3_SS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using 
FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x48x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x48x32_F32E4M3E4M3_RS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x64x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x64x32_F16E4M3E4M3_SS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 64, 32>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x64x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x64x32_F16E4M3E4M3_RS_TN; - -template -struct MMA_Traits> -{ - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 64, 32>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x64x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x64x32_F32E4M3E4M3_SS_TN; +using SM90_64x16x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x16x32_S32S8S8_SS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_64,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 64, 32>; - using CLayout = GMMA::CLayout_64x64; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -12583,169 +2412,140 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using 
SM90_64x64x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x64x32_F32E4M3E4M3_RS_TN; +using SM90_64x16x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32S8S8_SS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_64,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 64, 32>; - using CLayout = GMMA::CLayout_64x64; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x80x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x80x32_F16E4M3E4M3_SS_TN; +using SM90_64x32x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x32x32_S32S8S8_SS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout< 32, 32>; + 
using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x80x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x80x32_F16E4M3E4M3_RS_TN; +using SM90_64x32x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32S8S8_SS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x80x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x80x32_F32E4M3E4M3_SS_TN; +using SM90_64x64x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x64x32_S32S8S8_SS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + 
using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_32>; + using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x80x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x80x32_F32E4M3E4M3_RS_TN; +using SM90_64x64x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32S8S8_SS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_32>; + using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x96x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x96x32_F16E4M3E4M3_SS_TN; +using 
SM90_64x96x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x96x32_S32S8S8_SS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; @@ -12762,25 +2562,22 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x96x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x96x32_F16E4M3E4M3_RS_TN; +using SM90_64x96x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32S8S8_SS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; + using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout< 96, 32>; using CLayout = GMMA::CLayout_64x96; @@ -12790,28 +2587,24 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x96x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x96x32_F32E4M3E4M3_SS_TN; +using SM90_64x128x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x128x32_S32S8S8_SS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using 
ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_96,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 96, 32>; - using CLayout = GMMA::CLayout_64x96; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -12819,178 +2612,148 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x96x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x96x32_F32E4M3E4M3_RS_TN; +using SM90_64x128x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32S8S8_SS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_96,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 96, 32>; - using CLayout = GMMA::CLayout_64x96; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x112x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x112x32_F16E4M3E4M3_SS_TN; +using SM90_64x192x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x192x32_S32S8S8_SS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x112x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x112x32_F16E4M3E4M3_RS_TN; +using SM90_64x192x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32S8S8_SS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = 
GMMA::CLayout_64x112; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x112x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x112x32_F32E4M3E4M3_SS_TN; +using SM90_64x256x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x256x32_S32S8S8_SS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_32>; + using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x112x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x112x32_F32E4M3E4M3_RS_TN; +using SM90_64x256x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32S8S8_SS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + 
using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_32>; + using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x128x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x128x32_F16E4M3E4M3_SS_TN; +using SM90_64x8x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x8x32_S32S8S8_RS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_32>; + using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -12998,27 +2761,23 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using 
SM90_64x128x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x128x32_F16E4M3E4M3_RS_TN; +using SM90_64x8x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32S8S8_RS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_32>; + using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -13026,28 +2785,23 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x128x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x128x32_F32E4M3E4M3_SS_TN; +using SM90_64x16x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x16x32_S32S8S8_RS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; 
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -13055,422 +2809,336 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x128x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x128x32_F32E4M3E4M3_RS_TN; +using SM90_64x16x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32S8S8_RS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x144x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x144x32_F16E4M3E4M3_SS_TN; +using SM90_64x32x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x32x32_S32S8S8_RS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = 
Shape<_64,_144,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x144x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x144x32_F16E4M3E4M3_RS_TN; +using SM90_64x32x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32S8S8_RS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x144x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x144x32_F32E4M3E4M3_SS_TN; +using SM90_64x64x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x64x32_S32S8S8_RS_TN; -template -struct MMA_Traits> +template <> 
+struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_32>; + using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x144x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x144x32_F32E4M3E4M3_RS_TN; +using SM90_64x64x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32S8S8_RS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_32>; + using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x160x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x160x32_F16E4M3E4M3_SS_TN; +using SM90_64x96x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x96x32_S32S8S8_RS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_32>; + using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x160x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x160x32_F16E4M3E4M3_RS_TN; +using SM90_64x96x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32S8S8_RS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_32>; + using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; using ALayout = 
GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x160x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x160x32_F32E4M3E4M3_SS_TN; +using SM90_64x128x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x128x32_S32S8S8_RS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x160x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x160x32_F32E4M3E4M3_RS_TN; +using SM90_64x128x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32S8S8_RS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA 
= float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x176x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x176x32_F16E4M3E4M3_SS_TN; +using SM90_64x192x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x192x32_S32S8S8_RS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x176x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x176x32_F16E4M3E4M3_RS_TN; +using SM90_64x192x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32S8S8_RS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x176x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x176x32_F32E4M3E4M3_SS_TN; +using SM90_64x256x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x256x32_S32S8S8_RS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_32>; + using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; + using ALayout = 
GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x176x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x176x32_F32E4M3E4M3_RS_TN; +using SM90_64x256x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32S8S8_RS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_32>; + using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x192x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x192x32_F16E4M3E4M3_SS_TN; +using SM90_64x8x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x8x32_S32S8U8_SS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeA = 
GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_192,_32>; + using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -13478,27 +3146,24 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x192x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x192x32_F16E4M3E4M3_RS_TN; +using SM90_64x8x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32S8U8_SS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_192,_32>; + using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -13506,28 +3171,24 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x192x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x192x32_F32E4M3E4M3_SS_TN; +using SM90_64x16x32_S32S8U8_SS_TN = 
SM90::GMMA::MMA_64x16x32_S32S8U8_SS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_192,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -13535,422 +3196,348 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x192x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x192x32_F32E4M3E4M3_RS_TN; +using SM90_64x16x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32S8U8_SS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_192,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; 
//////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x208x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x208x32_F16E4M3E4M3_SS_TN; +using SM90_64x32x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x32x32_S32S8U8_SS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_208,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x208x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x208x32_F16E4M3E4M3_RS_TN; +using SM90_64x32x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32S8U8_SS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = 
Shape<_64,_208,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x208x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x208x32_F32E4M3E4M3_SS_TN; +using SM90_64x64x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x64x32_S32S8U8_SS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_208,_32>; + using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x208x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x208x32_F32E4M3E4M3_RS_TN; +using SM90_64x64x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32S8U8_SS_TN_SATURATE; 
-template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_208,_32>; + using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x224x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x224x32_F16E4M3E4M3_SS_TN; +using SM90_64x96x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x96x32_S32S8U8_SS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_224,_32>; + using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif 
//////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x224x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x224x32_F16E4M3E4M3_RS_TN; +using SM90_64x96x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32S8U8_SS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_224,_32>; + using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x224x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x224x32_F32E4M3E4M3_SS_TN; +using SM90_64x128x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x128x32_S32S8U8_SS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = 
GMMA::smem_desc; - using Shape_MNK = Shape<_64,_224,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x224x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x224x32_F32E4M3E4M3_RS_TN; +using SM90_64x128x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32S8U8_SS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_224,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x240x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x240x32_F16E4M3E4M3_SS_TN; +using SM90_64x192x32_S32S8U8_SS_TN = 
SM90::GMMA::MMA_64x192x32_S32S8U8_SS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_240,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x240x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x240x32_F16E4M3E4M3_RS_TN; +using SM90_64x192x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32S8U8_SS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_240,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; 
-#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x240x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x240x32_F32E4M3E4M3_SS_TN; +using SM90_64x256x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x256x32_S32S8U8_SS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_240,_32>; + using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x240x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x240x32_F32E4M3E4M3_RS_TN; +using SM90_64x256x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32S8U8_SS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = 
Shape<_64,_240,_32>; + using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x256x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x256x32_F16E4M3E4M3_SS_TN; +using SM90_64x8x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x8x32_S32S8U8_RS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_256,_32>; + using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<256, 32>; - using CLayout = GMMA::CLayout_64x256; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -13958,27 +3545,23 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x256x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x256x32_F16E4M3E4M3_RS_TN; +using SM90_64x8x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32S8U8_RS_TN_SATURATE; -template -struct MMA_Traits> 
+template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_256,_32>; + using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<256, 32>; - using CLayout = GMMA::CLayout_64x256; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -13986,28 +3569,23 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x256x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x256x32_F32E4M3E4M3_SS_TN; +using SM90_64x16x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x16x32_S32S8U8_RS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_256,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<256, 32>; - using CLayout = GMMA::CLayout_64x256; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -14015,27 +3593,23 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template 
< - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x256x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x256x32_F32E4M3E4M3_RS_TN; +using SM90_64x16x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32S8U8_RS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_256,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<256, 32>; - using CLayout = GMMA::CLayout_64x256; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -14043,28 +3617,23 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x8x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x8x32_F16E4M3E5M2_SS_TN; +using SM90_64x32x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x32x32_S32S8U8_RS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_8,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; + using ALayout = 
GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -14072,27 +3641,23 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x8x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x8x32_F16E4M3E5M2_RS_TN; +using SM90_64x32x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32S8U8_RS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_8,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -14100,28 +3665,23 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x8x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x8x32_F32E4M3E5M2_SS_TN; +using SM90_64x64x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x64x32_S32S8U8_RS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using FrgTypeA = 
GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_8,_32>; + using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -14129,27 +3689,23 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x8x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x8x32_F32E4M3E5M2_RS_TN; +using SM90_64x64x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32S8U8_RS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_8,_32>; + using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -14157,28 +3713,23 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x16x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x16x32_F16E4M3E5M2_SS_TN; +using SM90_64x96x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x96x32_S32S8U8_RS_TN; -template -struct 
MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_16,_32>; + using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -14186,27 +3737,23 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x16x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x16x32_F16E4M3E5M2_RS_TN; +using SM90_64x96x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32S8U8_RS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_16,_32>; + using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -14214,28 +3761,23 @@ struct MMA_Traits> 
//////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x16x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x16x32_F32E4M3E5M2_SS_TN; +using SM90_64x128x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x128x32_S32S8U8_RS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_16,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -14243,27 +3785,23 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x16x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x16x32_F32E4M3E5M2_RS_TN; +using SM90_64x128x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32S8U8_RS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_16,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = 
Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -14271,28 +3809,23 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x32x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x32x32_F16E4M3E5M2_SS_TN; +using SM90_64x192x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x192x32_S32S8U8_RS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_32,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -14300,27 +3833,23 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x32x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x32x32_F16E4M3E5M2_RS_TN; +using SM90_64x192x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32S8U8_RS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = 
float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_32,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -14328,28 +3857,23 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x32x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x32x32_F32E4M3E5M2_SS_TN; +using SM90_64x256x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x256x32_S32S8U8_RS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_32,_32>; + using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -14357,178 +3881,148 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One -> -using SM90_64x32x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x32x32_F32E4M3E5M2_RS_TN; +using SM90_64x256x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32S8U8_RS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_32,_32>; + using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x48x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x48x32_F16E4M3E5M2_SS_TN; +using SM90_64x8x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x8x32_S32U8S8_SS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_48,_32>; + using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut 
accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x48x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x48x32_F16E4M3E5M2_RS_TN; +using SM90_64x8x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32U8S8_SS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_48,_32>; + using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x48x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x48x32_F32E4M3E5M2_SS_TN; +using SM90_64x16x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x16x32_S32U8S8_SS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeA = 
GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_48,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x48x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x48x32_F32E4M3E5M2_RS_TN; +using SM90_64x16x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32U8S8_SS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_48,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x64x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x64x32_F16E4M3E5M2_SS_TN; +using SM90_64x32x32_S32U8S8_SS_TN = 
SM90::GMMA::MMA_64x32x32_S32U8S8_SS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_64,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 64, 32>; - using CLayout = GMMA::CLayout_64x64; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -14536,27 +4030,24 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x64x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x64x32_F16E4M3E5M2_RS_TN; +using SM90_64x32x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32U8S8_SS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_64,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 64, 32>; - using CLayout = GMMA::CLayout_64x64; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -14564,19 +4055,15 @@ 
struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x64x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x64x32_F32E4M3E5M2_SS_TN; +using SM90_64x64x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x64x32_S32U8S8_SS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; @@ -14593,25 +4080,22 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x64x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x64x32_F32E4M3E5M2_RS_TN; +using SM90_64x64x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32U8S8_SS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; + using ALayout = GMMA::ABLayout< 64, 32>; using BLayout = GMMA::ABLayout< 64, 32>; using CLayout = GMMA::CLayout_64x64; @@ -14620,151 +4104,125 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA 
= GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x80x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x80x32_F16E4M3E5M2_SS_TN; +using SM90_64x96x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x96x32_S32U8S8_SS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_32>; + using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x80x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x80x32_F16E4M3E5M2_RS_TN; +using SM90_64x96x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32U8S8_SS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_32>; + using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; 
+ using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x80x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x80x32_F32E4M3E5M2_SS_TN; +using SM90_64x128x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x128x32_S32U8S8_SS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x80x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x80x32_F32E4M3E5M2_RS_TN; +using SM90_64x128x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32U8S8_SS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeD = int32_t; + 
using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x96x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x96x32_F16E4M3E5M2_SS_TN; +using SM90_64x192x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x192x32_S32U8S8_SS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_96,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 96, 32>; - using CLayout = GMMA::CLayout_64x96; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -14772,27 +4230,24 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x96x32_F16E4M3E5M2_RS_TN = 
SM90::GMMA::MMA_64x96x32_F16E4M3E5M2_RS_TN; +using SM90_64x192x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32U8S8_SS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_96,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 96, 32>; - using CLayout = GMMA::CLayout_64x96; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -14800,28 +4255,24 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x96x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x96x32_F32E4M3E5M2_SS_TN; +using SM90_64x256x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x256x32_S32U8S8_SS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_96,_32>; + using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 96, 32>; - using CLayout = GMMA::CLayout_64x96; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = 
GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -14829,178 +4280,144 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x96x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x96x32_F32E4M3E5M2_RS_TN; +using SM90_64x256x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32U8S8_SS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_96,_32>; + using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 96, 32>; - using CLayout = GMMA::CLayout_64x96; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x112x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x112x32_F16E4M3E5M2_SS_TN; +using SM90_64x8x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x8x32_S32U8S8_RS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - 
using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_32>; + using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x112x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x112x32_F16E4M3E5M2_RS_TN; +using SM90_64x8x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32U8S8_RS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_32>; + using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x112x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x112x32_F32E4M3E5M2_SS_TN; +using SM90_64x16x32_S32U8S8_RS_TN = 
SM90::GMMA::MMA_64x16x32_S32U8S8_RS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x112x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x112x32_F32E4M3E5M2_RS_TN; +using SM90_64x16x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32U8S8_RS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif 
//////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x128x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x128x32_F16E4M3E5M2_SS_TN; +using SM90_64x32x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x32x32_S32U8S8_RS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -15008,27 +4425,23 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x128x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x128x32_F16E4M3E5M2_RS_TN; +using SM90_64x32x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32U8S8_RS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = 
Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -15036,28 +4449,23 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x128x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x128x32_F32E4M3E5M2_SS_TN; +using SM90_64x64x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x64x32_S32U8S8_RS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_32>; + using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -15065,422 +4473,340 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x128x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x128x32_F32E4M3E5M2_RS_TN; +using SM90_64x64x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32U8S8_RS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = 
float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_32>; + using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x144x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x144x32_F16E4M3E5M2_SS_TN; -template -struct MMA_Traits> +using SM90_64x96x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x96x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_32>; + using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, 
- GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x144x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x144x32_F16E4M3E5M2_RS_TN; +using SM90_64x96x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32U8S8_RS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_32>; + using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x144x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x144x32_F32E4M3E5M2_SS_TN; +using SM90_64x128x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x128x32_S32U8S8_RS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; + using ALayout = GMMA::ALayout_64x32; + using 
BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x144x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x144x32_F32E4M3E5M2_RS_TN; +using SM90_64x128x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32U8S8_RS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x160x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x160x32_F16E4M3E5M2_SS_TN; +using SM90_64x192x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x192x32_S32U8S8_RS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - 
using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x160x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x160x32_F16E4M3E5M2_RS_TN; +using SM90_64x192x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32U8S8_RS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x160x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x160x32_F32E4M3E5M2_SS_TN; +using 
SM90_64x256x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x256x32_S32U8S8_RS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_32>; + using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x160x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x160x32_F32E4M3E5M2_RS_TN; +using SM90_64x256x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32U8S8_RS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_32>; + using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif 
//////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x176x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x176x32_F16E4M3E5M2_SS_TN; +using SM90_64x8x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x8x32_S32U8U8_SS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_32>; + using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x176x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x176x32_F16E4M3E5M2_RS_TN; +using SM90_64x8x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32U8U8_SS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_32>; 
+ using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x176x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x176x32_F32E4M3E5M2_SS_TN; +using SM90_64x16x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x16x32_S32U8U8_SS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x176x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x176x32_F32E4M3E5M2_RS_TN; +using SM90_64x16x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32U8U8_SS_TN_SATURATE; -template -struct 
MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x192x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x192x32_F16E4M3E5M2_SS_TN; +using SM90_64x32x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x32x32_S32U8U8_SS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_192,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -15488,27 +4814,24 @@ struct MMA_Traits> 
//////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x192x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x192x32_F16E4M3E5M2_RS_TN; +using SM90_64x32x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32U8U8_SS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_192,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -15516,28 +4839,24 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x192x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x192x32_F32E4M3E5M2_SS_TN; +using SM90_64x64x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x64x32_S32U8U8_SS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_192,_32>; + using Shape_MNK = 
Shape<_64,_64,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -15545,422 +4864,344 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x192x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x192x32_F32E4M3E5M2_RS_TN; +using SM90_64x64x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32U8U8_SS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_192,_32>; + using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x208x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x208x32_F16E4M3E5M2_SS_TN; +using SM90_64x96x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x96x32_S32U8U8_SS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - 
using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_208,_32>; + using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x208x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x208x32_F16E4M3E5M2_RS_TN; +using SM90_64x96x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32U8U8_SS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_208,_32>; + using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// 
-#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x208x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x208x32_F32E4M3E5M2_SS_TN; +using SM90_64x128x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x128x32_S32U8U8_SS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_208,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x208x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x208x32_F32E4M3E5M2_RS_TN; +using SM90_64x128x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32U8U8_SS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_208,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; - using ALayout = 
GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x224x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x224x32_F16E4M3E5M2_SS_TN; +using SM90_64x192x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x192x32_S32U8U8_SS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_224,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x224x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x224x32_F16E4M3E5M2_RS_TN; +using SM90_64x192x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32U8U8_SS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using 
ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_224,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x224x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x224x32_F32E4M3E5M2_SS_TN; +using SM90_64x256x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x256x32_S32U8U8_SS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_224,_32>; + using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x224x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x224x32_F32E4M3E5M2_RS_TN; +using SM90_64x256x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32U8U8_SS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_224,_32>; + using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x240x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x240x32_F16E4M3E5M2_SS_TN; +using SM90_64x8x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x8x32_S32U8U8_RS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_240,_32>; + using Shape_MNK = Shape<_64,_8,_32>; using ThrID = 
Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x240x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x240x32_F16E4M3E5M2_RS_TN; +using SM90_64x8x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32U8U8_RS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_240,_32>; + using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x240x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x240x32_F32E4M3E5M2_SS_TN; +using SM90_64x16x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x16x32_S32U8U8_RS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; 
- using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_240,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x240x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x240x32_F32E4M3E5M2_RS_TN; +using SM90_64x16x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32U8U8_RS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_240,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One 
-> -using SM90_64x256x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x256x32_F16E4M3E5M2_SS_TN; +using SM90_64x32x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x32x32_S32U8U8_RS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_256,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<256, 32>; - using CLayout = GMMA::CLayout_64x256; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -15968,27 +5209,23 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x256x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x256x32_F16E4M3E5M2_RS_TN; +using SM90_64x32x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32U8U8_RS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_256,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<256, 32>; - using CLayout = GMMA::CLayout_64x256; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = 
GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -15996,28 +5233,23 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x256x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x256x32_F32E4M3E5M2_SS_TN; +using SM90_64x64x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x64x32_S32U8U8_RS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_256,_32>; + using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<256, 32>; - using CLayout = GMMA::CLayout_64x256; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -16025,27 +5257,23 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x256x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x256x32_F32E4M3E5M2_RS_TN; +using SM90_64x64x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32U8U8_RS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e4m3_t; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeB 
= GMMA::smem_desc; - using Shape_MNK = Shape<_64,_256,_32>; + using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<256, 32>; - using CLayout = GMMA::CLayout_64x256; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -16053,28 +5281,23 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x8x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x8x32_F16E5M2E4M3_SS_TN; +using SM90_64x96x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x96x32_S32U8U8_RS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_8,_32>; + using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -16082,27 +5305,23 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x8x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x8x32_F16E5M2E4M3_RS_TN; +using SM90_64x96x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32U8U8_RS_TN_SATURATE; -template 
-struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_8,_32>; + using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -16110,28 +5329,23 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x8x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x8x32_F32E5M2E4M3_SS_TN; +using SM90_64x128x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x128x32_S32U8U8_RS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_8,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -16139,27 +5353,23 @@ struct MMA_Traits> 
//////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x8x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x8x32_F32E5M2E4M3_RS_TN; +using SM90_64x128x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32U8U8_RS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_8,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -16167,28 +5377,23 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x16x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x16x32_F16E5M2E4M3_SS_TN; +using SM90_64x192x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x192x32_S32U8U8_RS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_16,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 
32>; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -16196,27 +5401,23 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x16x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x16x32_F16E5M2E4M3_RS_TN; +using SM90_64x192x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32U8U8_RS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_16,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -16224,28 +5425,23 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x16x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x16x32_F32E5M2E4M3_SS_TN; +using SM90_64x256x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x256x32_S32U8U8_RS_TN; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeD 
= int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_16,_32>; + using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -16253,517 +5449,480 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, - GMMA::ScaleIn scaleB = GMMA::ScaleIn::One -> -using SM90_64x16x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x16x32_F32E5M2E4M3_RS_TN; +using SM90_64x256x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32U8U8_RS_TN_SATURATE; -template -struct MMA_Traits> +template <> +struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_16,_32>; + using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x32x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x32x32_F16E5M2E4M3_SS_TN; +using 
SM90_64x8x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x8x32_F16E4M3E4M3_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_32,_32>; + using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x32x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x32x32_F16E5M2E4M3_RS_TN; +using SM90_64x8x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x8x32_F16E4M3E4M3_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_32,_32>; + using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x32x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x32x32_F32E5M2E4M3_SS_TN; +using SM90_64x8x32_F32E4M3E4M3_SS_TN = 
SM90::GMMA::MMA_64x8x32_F32E4M3E4M3_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_32,_32>; + using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x32x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x32x32_F32E5M2E4M3_RS_TN; +using SM90_64x8x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x8x32_F32E4M3E4M3_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_32,_32>; + using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x48x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x48x32_F16E5M2E4M3_SS_TN; +using 
SM90_64x16x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x16x32_F16E4M3E4M3_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_48,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x48x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x48x32_F16E5M2E4M3_RS_TN; +using SM90_64x16x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x16x32_F16E4M3E4M3_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_48,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > 
-using SM90_64x48x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x48x32_F32E5M2E4M3_SS_TN; +using SM90_64x16x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x16x32_F32E4M3E4M3_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_48,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x48x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x48x32_F32E5M2E4M3_RS_TN; +using SM90_64x16x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x16x32_F32E4M3E4M3_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_48,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x64x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x64x32_F16E5M2E4M3_SS_TN; +using SM90_64x32x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x32x32_F16E4M3E4M3_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_64,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 64, 32>; - using CLayout = GMMA::CLayout_64x64; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x64x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x64x32_F16E5M2E4M3_RS_TN; +using SM90_64x32x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x32x32_F16E4M3E4M3_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_64,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 64, 32>; - using CLayout = GMMA::CLayout_64x64; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One > -using SM90_64x64x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x64x32_F32E5M2E4M3_SS_TN; +using SM90_64x32x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x32x32_F32E4M3E4M3_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_64,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 64, 32>; - using CLayout = GMMA::CLayout_64x64; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x64x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x64x32_F32E5M2E4M3_RS_TN; +using SM90_64x32x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x32x32_F32E4M3E4M3_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_64,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 64, 32>; - using CLayout = GMMA::CLayout_64x64; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x80x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x80x32_F16E5M2E4M3_SS_TN; +using SM90_64x64x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x64x32_F16E4M3E4M3_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_32>; + using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x80x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x80x32_F16E5M2E4M3_RS_TN; +using SM90_64x64x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x64x32_F16E4M3E4M3_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_32>; + using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x80x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x80x32_F32E5M2E4M3_SS_TN; +using SM90_64x64x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x64x32_F32E4M3E4M3_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_32>; + using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x80x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x80x32_F32E5M2E4M3_RS_TN; +using SM90_64x64x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x64x32_F32E4M3E4M3_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_32>; + using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif 
//////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x96x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x96x32_F16E5M2E4M3_SS_TN; +using SM90_64x96x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x96x32_F16E4M3E4M3_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e4m3_t; using ValTypeC = half_t; @@ -16781,18 +5940,17 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x96x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x96x32_F16E5M2E4M3_RS_TN; +using SM90_64x96x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x96x32_F16E4M3E4M3_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e4m3_t; using ValTypeC = half_t; @@ -16809,18 +5967,17 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x96x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x96x32_F32E5M2E4M3_SS_TN; +using SM90_64x96x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x96x32_F32E4M3E4M3_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e4m3_t; using ValTypeC = float; @@ -16838,18 +5995,17 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x96x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x96x32_F32E5M2E4M3_RS_TN; +using SM90_64x96x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x96x32_F32E4M3E4M3_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e4m3_t; using ValTypeC = float; @@ -16866,1679 +6022,1558 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x112x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x112x32_F16E5M2E4M3_SS_TN; +using SM90_64x128x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x128x32_F16E4M3E4M3_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x112x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x112x32_F16E5M2E4M3_RS_TN; +using SM90_64x128x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x128x32_F16E4M3E4M3_RS_TN; template -struct MMA_Traits> +struct 
MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x112x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x112x32_F32E5M2E4M3_SS_TN; +using SM90_64x128x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x128x32_F32E4M3E4M3_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x112x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x112x32_F32E5M2E4M3_RS_TN; +using 
SM90_64x128x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x128x32_F32E4M3E4M3_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x128x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x128x32_F16E5M2E4M3_SS_TN; +using SM90_64x192x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x192x32_F16E4M3E4M3_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x128x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x128x32_F16E5M2E4M3_RS_TN; +using 
SM90_64x192x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x192x32_F16E4M3E4M3_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x128x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x128x32_F32E5M2E4M3_SS_TN; +using SM90_64x192x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x192x32_F32E4M3E4M3_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x128x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x128x32_F32E5M2E4M3_RS_TN; +using 
SM90_64x192x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x192x32_F32E4M3E4M3_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x144x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x144x32_F16E5M2E4M3_SS_TN; +using SM90_64x256x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x256x32_F16E4M3E4M3_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_32>; + using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > 
-using SM90_64x144x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x144x32_F16E5M2E4M3_RS_TN; +using SM90_64x256x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x256x32_F16E4M3E4M3_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_32>; + using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x144x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x144x32_F32E5M2E4M3_SS_TN; +using SM90_64x256x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x256x32_F32E4M3E4M3_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_32>; + using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x144x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x144x32_F32E5M2E4M3_RS_TN; +using SM90_64x256x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x256x32_F32E4M3E4M3_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_32>; + using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x160x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x160x32_F16E5M2E4M3_SS_TN; +using SM90_64x8x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x8x32_F16E4M3E5M2_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_32>; + using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x160x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x160x32_F16E5M2E4M3_RS_TN; +using SM90_64x8x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x8x32_F16E4M3E5M2_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_32>; + using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x160x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x160x32_F32E5M2E4M3_SS_TN; +using SM90_64x8x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x8x32_F32E4M3E5M2_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_32>; + using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<160, 
32>; - using CLayout = GMMA::CLayout_64x160; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x160x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x160x32_F32E5M2E4M3_RS_TN; +using SM90_64x8x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x8x32_F32E4M3E5M2_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_32>; + using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x176x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x176x32_F16E5M2E4M3_SS_TN; +using SM90_64x16x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x16x32_F16E4M3E5M2_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = 
Shape<_64,_176,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x176x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x176x32_F16E5M2E4M3_RS_TN; +using SM90_64x16x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x16x32_F16E4M3E5M2_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x176x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x176x32_F32E5M2E4M3_SS_TN; +using SM90_64x16x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x16x32_F32E4M3E5M2_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; + using 
ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x176x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x176x32_F32E5M2E4M3_RS_TN; +using SM90_64x16x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x16x32_F32E4M3E5M2_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x192x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x192x32_F16E5M2E4M3_SS_TN; +using SM90_64x32x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x32x32_F16E4M3E5M2_SS_TN; template -struct MMA_Traits> +struct 
MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_192,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x192x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x192x32_F16E5M2E4M3_RS_TN; +using SM90_64x32x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x32x32_F16E4M3E5M2_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_192,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x192x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x192x32_F32E5M2E4M3_SS_TN; +using SM90_64x32x32_F32E4M3E5M2_SS_TN = 
SM90::GMMA::MMA_64x32x32_F32E4M3E5M2_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_192,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x192x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x192x32_F32E5M2E4M3_RS_TN; +using SM90_64x32x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x32x32_F32E4M3E5M2_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_192,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using 
SM90_64x208x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x208x32_F16E5M2E4M3_SS_TN; +using SM90_64x64x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x64x32_F16E4M3E5M2_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_208,_32>; + using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x208x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x208x32_F16E5M2E4M3_RS_TN; +using SM90_64x64x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x64x32_F16E4M3E5M2_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_208,_32>; + using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif 
//////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x208x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x208x32_F32E5M2E4M3_SS_TN; +using SM90_64x64x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x64x32_F32E4M3E5M2_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_208,_32>; + using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x208x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x208x32_F32E5M2E4M3_RS_TN; +using SM90_64x64x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x64x32_F32E4M3E5M2_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_208,_32>; + using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = 
GMMA::CLayout_64x208; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x224x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x224x32_F16E5M2E4M3_SS_TN; +using SM90_64x96x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x96x32_F16E4M3E5M2_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_224,_32>; + using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x224x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x224x32_F16E5M2E4M3_RS_TN; +using SM90_64x96x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x96x32_F16E4M3E5M2_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_224,_32>; 
+ using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x224x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x224x32_F32E5M2E4M3_SS_TN; +using SM90_64x96x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x96x32_F32E4M3E5M2_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_224,_32>; + using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x224x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x224x32_F32E5M2E4M3_RS_TN; +using SM90_64x96x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x96x32_F32E4M3E5M2_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; + 
using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_224,_32>; + using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x240x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x240x32_F16E5M2E4M3_SS_TN; +using SM90_64x128x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x128x32_F16E4M3E5M2_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_240,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x240x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x240x32_F16E5M2E4M3_RS_TN; +using SM90_64x128x32_F16E4M3E5M2_RS_TN = 
SM90::GMMA::MMA_64x128x32_F16E4M3E5M2_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_240,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x240x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x240x32_F32E5M2E4M3_SS_TN; +using SM90_64x128x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x128x32_F32E4M3E5M2_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_240,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x240x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x240x32_F32E5M2E4M3_RS_TN; +using SM90_64x128x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x128x32_F32E4M3E5M2_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_240,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x256x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x256x32_F16E5M2E4M3_SS_TN; +using SM90_64x192x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x192x32_F16E4M3E5M2_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_256,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<256, 32>; - using CLayout = GMMA::CLayout_64x256; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; 
//////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x256x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x256x32_F16E5M2E4M3_RS_TN; +using SM90_64x192x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x192x32_F16E4M3E5M2_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_256,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<256, 32>; - using CLayout = GMMA::CLayout_64x256; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x256x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x256x32_F32E5M2E4M3_SS_TN; +using SM90_64x192x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x192x32_F32E4M3E5M2_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_256,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<256, 32>; - using CLayout = GMMA::CLayout_64x256; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; 
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x256x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x256x32_F32E5M2E4M3_RS_TN; +using SM90_64x192x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x192x32_F32E4M3E5M2_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = float_e5m2_t; - using ValTypeB = float_e4m3_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_256,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<256, 32>; - using CLayout = GMMA::CLayout_64x256; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x8x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x8x32_F16E5M2E5M2_SS_TN; +using SM90_64x256x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x256x32_F16E4M3E5M2_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e5m2_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_8,_32>; + using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = 
GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x8x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x8x32_F16E5M2E5M2_RS_TN; +using SM90_64x256x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x256x32_F16E4M3E5M2_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e5m2_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_8,_32>; + using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x8x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x8x32_F32E5M2E5M2_SS_TN; +using SM90_64x256x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x256x32_F32E4M3E5M2_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e5m2_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_8,_32>; + using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; 
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x8x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x8x32_F32E5M2E5M2_RS_TN; +using SM90_64x256x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x256x32_F32E4M3E5M2_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = float_e5m2_t; + using ValTypeA = float_e4m3_t; using ValTypeB = float_e5m2_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_8,_32>; + using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x16x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x16x32_F16E5M2E5M2_SS_TN; +using SM90_64x8x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x8x32_F16E5M2E4M3_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = float_e5m2_t; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_16,_32>; + using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x16x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x16x32_F16E5M2E5M2_RS_TN; +using SM90_64x8x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x8x32_F16E5M2E4M3_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = float_e5m2_t; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_16,_32>; + using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x16x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x16x32_F32E5M2E5M2_SS_TN; +using SM90_64x8x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x8x32_F32E5M2E4M3_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = float_e5m2_t; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_16,_32>; + using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; 
//////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x16x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x16x32_F32E5M2E5M2_RS_TN; +using SM90_64x8x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x8x32_F32E5M2E4M3_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = float_e5m2_t; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_16,_32>; + using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x32x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x32x32_F16E5M2E5M2_SS_TN; +using SM90_64x16x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x16x32_F16E5M2E4M3_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = float_e5m2_t; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_32,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; 
//////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x32x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x32x32_F16E5M2E5M2_RS_TN; +using SM90_64x16x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x16x32_F16E5M2E4M3_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = float_e5m2_t; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_32,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x32x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x32x32_F32E5M2E5M2_SS_TN; +using SM90_64x16x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x16x32_F32E5M2E4M3_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = float_e5m2_t; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_32,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x32x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x32x32_F32E5M2E5M2_RS_TN; +using SM90_64x16x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x16x32_F32E5M2E4M3_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = float_e5m2_t; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_32,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x48x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x48x32_F16E5M2E5M2_SS_TN; +using SM90_64x32x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x32x32_F16E5M2E4M3_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = float_e5m2_t; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_48,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; 
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x48x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x48x32_F16E5M2E5M2_RS_TN; +using SM90_64x32x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x32x32_F16E5M2E4M3_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = float_e5m2_t; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_48,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x48x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x48x32_F32E5M2E5M2_SS_TN; +using SM90_64x32x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x32x32_F32E5M2E4M3_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = float_e5m2_t; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_48,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = 
GMMA::CLayout_64x48; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x48x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x48x32_F32E5M2E5M2_RS_TN; +using SM90_64x32x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x32x32_F32E5M2E4M3_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = float_e5m2_t; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_48,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x64x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x64x32_F16E5M2E5M2_SS_TN; +using SM90_64x64x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x64x32_F16E5M2E4M3_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = float_e5m2_t; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; @@ -18555,19 +7590,18 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x64x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x64x32_F16E5M2E5M2_RS_TN; +using SM90_64x64x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x64x32_F16E5M2E4M3_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = float_e5m2_t; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; @@ -18583,19 +7617,18 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x64x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x64x32_F32E5M2E5M2_SS_TN; +using SM90_64x64x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x64x32_F32E5M2E4M3_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = float_e5m2_t; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; @@ -18612,19 +7645,18 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x64x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x64x32_F32E5M2E5M2_RS_TN; +using SM90_64x64x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x64x32_F32E5M2E4M3_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = float_e5m2_t; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; @@ -18640,488 +7672,454 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x80x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x80x32_F16E5M2E5M2_SS_TN; +using SM90_64x96x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x96x32_F16E5M2E4M3_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = float_e5m2_t; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_32>; + using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x80x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x80x32_F16E5M2E5M2_RS_TN; +using SM90_64x96x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x96x32_F16E5M2E4M3_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = float_e5m2_t; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_32>; + using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x80x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x80x32_F32E5M2E5M2_SS_TN; +using SM90_64x96x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x96x32_F32E5M2E4M3_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = float_e5m2_t; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_32>; + using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x80x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x80x32_F32E5M2E5M2_RS_TN; +using SM90_64x96x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x96x32_F32E5M2E4M3_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = float_e5m2_t; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_32>; + using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif 
//////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x96x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x96x32_F16E5M2E5M2_SS_TN; +using SM90_64x128x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x128x32_F16E5M2E4M3_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = float_e5m2_t; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_96,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 96, 32>; - using CLayout = GMMA::CLayout_64x96; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x96x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x96x32_F16E5M2E5M2_RS_TN; +using SM90_64x128x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x128x32_F16E5M2E4M3_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = float_e5m2_t; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_96,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 96, 32>; - using CLayout = GMMA::CLayout_64x96; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; 
//////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x96x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x96x32_F32E5M2E5M2_SS_TN; +using SM90_64x128x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x128x32_F32E5M2E4M3_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = float_e5m2_t; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_96,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout< 96, 32>; - using CLayout = GMMA::CLayout_64x96; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x96x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x96x32_F32E5M2E5M2_RS_TN; +using SM90_64x128x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x128x32_F32E5M2E4M3_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = float_e5m2_t; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_96,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout< 96, 32>; - using CLayout = GMMA::CLayout_64x96; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; 
//////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x112x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x112x32_F16E5M2E5M2_SS_TN; +using SM90_64x192x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x192x32_F16E5M2E4M3_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = float_e5m2_t; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x112x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x112x32_F16E5M2E5M2_RS_TN; +using SM90_64x192x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x192x32_F16E5M2E4M3_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = float_e5m2_t; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout<192, 
32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x112x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x112x32_F32E5M2E5M2_SS_TN; +using SM90_64x192x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x192x32_F32E5M2E4M3_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = float_e5m2_t; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x112x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x112x32_F32E5M2E5M2_RS_TN; +using SM90_64x192x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x192x32_F32E5M2E4M3_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = float_e5m2_t; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = 
GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x128x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x128x32_F16E5M2E5M2_SS_TN; +using SM90_64x256x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x256x32_F16E5M2E4M3_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = float_e5m2_t; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_32>; + using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x128x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x128x32_F16E5M2E5M2_RS_TN; +using SM90_64x256x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x256x32_F16E5M2E4M3_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = float_e5m2_t; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_32>; + using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = 
GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x128x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x128x32_F32E5M2E5M2_SS_TN; +using SM90_64x256x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x256x32_F32E5M2E4M3_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = float_e5m2_t; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_32>; + using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x128x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x128x32_F32E5M2E5M2_RS_TN; +using SM90_64x256x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x256x32_F32E5M2E4M3_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = float_e5m2_t; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_32>; + using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<128, 
32>; - using CLayout = GMMA::CLayout_64x128; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x144x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x144x32_F16E5M2E5M2_SS_TN; +using SM90_64x8x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x8x32_F16E5M2E5M2_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = float_e5m2_t; @@ -19131,28 +8129,25 @@ struct MMA_Traits> using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_32>; + using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x144x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x144x32_F16E5M2E5M2_RS_TN; +using SM90_64x8x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x8x32_F16E5M2E5M2_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = float_e5m2_t; @@ -19161,28 +8156,25 @@ struct MMA_Traits> using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_32>; + using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<144, 32>; - using 
CLayout = GMMA::CLayout_64x144; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x144x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x144x32_F32E5M2E5M2_SS_TN; +using SM90_64x8x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x8x32_F32E5M2E5M2_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = float_e5m2_t; @@ -19192,28 +8184,25 @@ struct MMA_Traits> using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_32>; + using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x144x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x144x32_F32E5M2E5M2_RS_TN; +using SM90_64x8x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x8x32_F32E5M2E5M2_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = float_e5m2_t; @@ -19222,28 +8211,25 @@ struct MMA_Traits> using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_32>; + using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = 
GMMA::CLayout_64x144; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x160x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x160x32_F16E5M2E5M2_SS_TN; +using SM90_64x16x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x16x32_F16E5M2E5M2_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = float_e5m2_t; @@ -19253,28 +8239,25 @@ struct MMA_Traits> using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x160x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x160x32_F16E5M2E5M2_RS_TN; +using SM90_64x16x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x16x32_F16E5M2E5M2_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = float_e5m2_t; @@ -19283,28 +8266,25 @@ struct MMA_Traits> using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = 
GMMA::CLayout_64x160; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x160x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x160x32_F32E5M2E5M2_SS_TN; +using SM90_64x16x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x16x32_F32E5M2E5M2_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = float_e5m2_t; @@ -19314,28 +8294,25 @@ struct MMA_Traits> using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x160x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x160x32_F32E5M2E5M2_RS_TN; +using SM90_64x16x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x16x32_F32E5M2E5M2_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = float_e5m2_t; @@ -19344,28 +8321,25 @@ struct MMA_Traits> using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = 
GMMA::CLayout_64x160; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x176x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x176x32_F16E5M2E5M2_SS_TN; +using SM90_64x32x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x32x32_F16E5M2E5M2_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = float_e5m2_t; @@ -19375,28 +8349,25 @@ struct MMA_Traits> using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x176x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x176x32_F16E5M2E5M2_RS_TN; +using SM90_64x32x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x32x32_F16E5M2E5M2_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = float_e5m2_t; @@ -19405,28 +8376,25 @@ struct MMA_Traits> using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout 
= GMMA::CLayout_64x176; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x176x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x176x32_F32E5M2E5M2_SS_TN; +using SM90_64x32x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x32x32_F32E5M2E5M2_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = float_e5m2_t; @@ -19436,28 +8404,25 @@ struct MMA_Traits> using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x176x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x176x32_F32E5M2E5M2_RS_TN; +using SM90_64x32x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x32x32_F32E5M2E5M2_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = float_e5m2_t; @@ -19466,27 +8431,25 @@ struct MMA_Traits> using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout 
= GMMA::CLayout_64x176; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x192x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x192x32_F16E5M2E5M2_SS_TN; +using SM90_64x64x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x64x32_F16E5M2E5M2_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = float_e5m2_t; @@ -19496,26 +8459,25 @@ struct MMA_Traits> using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_192,_32>; + using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x192x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x192x32_F16E5M2E5M2_RS_TN; +using SM90_64x64x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x64x32_F16E5M2E5M2_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = float_e5m2_t; @@ -19524,26 +8486,25 @@ struct MMA_Traits> using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_192,_32>; + using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; 
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x192x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x192x32_F32E5M2E5M2_SS_TN; +using SM90_64x64x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x64x32_F32E5M2E5M2_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = float_e5m2_t; @@ -19553,26 +8514,25 @@ struct MMA_Traits> using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_192,_32>; + using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x192x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x192x32_F32E5M2E5M2_RS_TN; +using SM90_64x64x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x64x32_F32E5M2E5M2_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = float_e5m2_t; @@ -19581,27 +8541,25 @@ struct MMA_Traits> using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_192,_32>; + using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; 
//////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x208x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x208x32_F16E5M2E5M2_SS_TN; +using SM90_64x96x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x96x32_F16E5M2E5M2_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = float_e5m2_t; @@ -19611,28 +8569,25 @@ struct MMA_Traits> using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_208,_32>; + using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x208x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x208x32_F16E5M2E5M2_RS_TN; +using SM90_64x96x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x96x32_F16E5M2E5M2_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = float_e5m2_t; @@ -19641,28 +8596,25 @@ struct MMA_Traits> using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_208,_32>; + using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; 
-#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x208x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x208x32_F32E5M2E5M2_SS_TN; +using SM90_64x96x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x96x32_F32E5M2E5M2_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = float_e5m2_t; @@ -19672,28 +8624,25 @@ struct MMA_Traits> using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_208,_32>; + using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x208x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x208x32_F32E5M2E5M2_RS_TN; +using SM90_64x96x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x96x32_F32E5M2E5M2_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = float_e5m2_t; @@ -19702,28 +8651,25 @@ struct MMA_Traits> using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_208,_32>; + using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; 
-#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x224x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x224x32_F16E5M2E5M2_SS_TN; +using SM90_64x128x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x128x32_F16E5M2E5M2_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = float_e5m2_t; @@ -19733,28 +8679,25 @@ struct MMA_Traits> using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_224,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x224x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x224x32_F16E5M2E5M2_RS_TN; +using SM90_64x128x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x128x32_F16E5M2E5M2_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = float_e5m2_t; @@ -19763,28 +8706,25 @@ struct MMA_Traits> using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_224,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x224x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x224x32_F32E5M2E5M2_SS_TN; +using SM90_64x128x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x128x32_F32E5M2E5M2_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = float_e5m2_t; @@ -19794,28 +8734,25 @@ struct MMA_Traits> using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_224,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x224x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x224x32_F32E5M2E5M2_RS_TN; +using SM90_64x128x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x128x32_F32E5M2E5M2_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = float_e5m2_t; @@ -19824,28 +8761,25 @@ struct MMA_Traits> using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_224,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut 
accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x240x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x240x32_F16E5M2E5M2_SS_TN; +using SM90_64x192x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x192x32_F16E5M2E5M2_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = float_e5m2_t; @@ -19855,28 +8789,25 @@ struct MMA_Traits> using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_240,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x240x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x240x32_F16E5M2E5M2_RS_TN; +using SM90_64x192x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x192x32_F16E5M2E5M2_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = float_e5m2_t; @@ -19885,28 +8816,25 @@ struct MMA_Traits> using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_240,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; 
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x240x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x240x32_F32E5M2E5M2_SS_TN; +using SM90_64x192x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x192x32_F32E5M2E5M2_SS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = float_e5m2_t; @@ -19916,28 +8844,25 @@ struct MMA_Traits> using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_240,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One > -using SM90_64x240x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x240x32_F32E5M2E5M2_RS_TN; +using SM90_64x192x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x192x32_F32E5M2E5M2_RS_TN; template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = float_e5m2_t; @@ -19946,19 +8871,17 @@ struct MMA_Traits> using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_240,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = 
GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One @@ -19987,7 +8910,6 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One @@ -20015,7 +8937,6 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One @@ -20044,7 +8965,6 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, GMMA::ScaleIn scaleB = GMMA::ScaleIn::One @@ -20070,7 +8990,10 @@ struct MMA_Traits> GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; - //////////////////////////////////////////////////////////////////////////////////////////////////// } // end namespace cute + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +#include "mma_traits_sm90_gmma_ext.hpp" +#endif diff --git a/include/cute/atom/mma_traits_sm90_gmma_ext.hpp b/include/cute/atom/mma_traits_sm90_gmma_ext.hpp new file mode 100644 index 000000000..15e2412c8 --- /dev/null +++ b/include/cute/atom/mma_traits_sm90_gmma_ext.hpp @@ -0,0 +1,20116 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include +#include + +namespace cute { + +namespace SM90::GMMA { + +using CLayout_64x24 = CLayout_64xN< 24>; +using CLayout_64x40 = CLayout_64xN< 40>; +using CLayout_64x48 = CLayout_64xN< 48>; +using CLayout_64x56 = CLayout_64xN< 56>; +using CLayout_64x72 = CLayout_64xN< 72>; +using CLayout_64x80 = CLayout_64xN< 80>; +using CLayout_64x88 = CLayout_64xN< 88>; +using CLayout_64x104 = CLayout_64xN<104>; +using CLayout_64x112 = CLayout_64xN<112>; +using CLayout_64x120 = CLayout_64xN<120>; +using CLayout_64x136 = CLayout_64xN<136>; +using CLayout_64x144 = CLayout_64xN<144>; +using CLayout_64x152 = CLayout_64xN<152>; +using CLayout_64x160 = CLayout_64xN<160>; +using CLayout_64x168 = CLayout_64xN<168>; +using CLayout_64x176 = CLayout_64xN<176>; +using CLayout_64x184 = CLayout_64xN<184>; +using CLayout_64x200 = CLayout_64xN<200>; +using CLayout_64x208 = CLayout_64xN<208>; +using CLayout_64x216 = CLayout_64xN<216>; +using CLayout_64x224 = CLayout_64xN<224>; +using CLayout_64x232 = CLayout_64xN<232>; +using CLayout_64x240 = CLayout_64xN<240>; +using CLayout_64x248 = CLayout_64xN<248>; + +} + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x16_F16F16F16_SS = SM90::GMMA::MMA_64x24x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 24, 16>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x16_F16F16F16_RS = SM90::GMMA::MMA_64x24x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 24, 16>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x16_F16F16F16_SS = SM90::GMMA::MMA_64x40x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 40, 16>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x16_F16F16F16_RS = SM90::GMMA::MMA_64x40x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; 
+ using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 40, 16>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x16_F16F16F16_SS = SM90::GMMA::MMA_64x48x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 48, 16>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x16_F16F16F16_RS = SM90::GMMA::MMA_64x48x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 48, 16>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + 
GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x16_F16F16F16_SS = SM90::GMMA::MMA_64x56x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 56, 16>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x16_F16F16F16_RS = SM90::GMMA::MMA_64x56x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 56, 16>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x16_F16F16F16_SS = SM90::GMMA::MMA_64x72x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using 
Shape_MNK = Shape<_64,_72,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 72, 16>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x16_F16F16F16_RS = SM90::GMMA::MMA_64x72x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 72, 16>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F16F16F16_SS = SM90::GMMA::MMA_64x80x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 80, 16>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F16F16F16_RS = SM90::GMMA::MMA_64x80x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 80, 16>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x16_F16F16F16_SS = SM90::GMMA::MMA_64x88x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 88, 16>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x16_F16F16F16_RS = SM90::GMMA::MMA_64x88x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using 
BLayout = GMMA::ABLayout< 88, 16>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x16_F16F16F16_SS = SM90::GMMA::MMA_64x104x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<104, 16>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x16_F16F16F16_RS = SM90::GMMA::MMA_64x104x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<104, 16>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F16F16F16_SS = 
SM90::GMMA::MMA_64x112x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F16F16F16_RS = SM90::GMMA::MMA_64x112x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x16_F16F16F16_SS = SM90::GMMA::MMA_64x120x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<120, 16>; + using CLayout = 
GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x16_F16F16F16_RS = SM90::GMMA::MMA_64x120x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<120, 16>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x16_F16F16F16_SS = SM90::GMMA::MMA_64x136x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<136, 16>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x16_F16F16F16_RS = SM90::GMMA::MMA_64x136x16_F16F16F16_RS; + +template +struct 
MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<136, 16>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F16F16F16_SS = SM90::GMMA::MMA_64x144x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F16F16F16_RS = SM90::GMMA::MMA_64x144x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x16_F16F16F16_SS = SM90::GMMA::MMA_64x152x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<152, 16>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x16_F16F16F16_RS = SM90::GMMA::MMA_64x152x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<152, 16>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F16F16F16_SS = SM90::GMMA::MMA_64x160x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB 
= half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F16F16F16_RS = SM90::GMMA::MMA_64x160x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x16_F16F16F16_SS = SM90::GMMA::MMA_64x168x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<168, 16>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x16_F16F16F16_RS = SM90::GMMA::MMA_64x168x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<168, 16>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F16F16F16_SS = SM90::GMMA::MMA_64x176x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F16F16F16_RS = SM90::GMMA::MMA_64x176x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB 
= half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x16_F16F16F16_SS = SM90::GMMA::MMA_64x184x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<184, 16>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x16_F16F16F16_RS = SM90::GMMA::MMA_64x184x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<184, 16>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
+template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x16_F16F16F16_SS = SM90::GMMA::MMA_64x200x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<200, 16>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x16_F16F16F16_RS = SM90::GMMA::MMA_64x200x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<200, 16>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F16F16F16_SS = SM90::GMMA::MMA_64x208x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F16F16F16_RS = SM90::GMMA::MMA_64x208x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x16_F16F16F16_SS = SM90::GMMA::MMA_64x216x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<216, 16>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, 
+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x16_F16F16F16_RS = SM90::GMMA::MMA_64x216x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<216, 16>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F16F16F16_SS = SM90::GMMA::MMA_64x224x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F16F16F16_RS = SM90::GMMA::MMA_64x224x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using 
ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x16_F16F16F16_SS = SM90::GMMA::MMA_64x232x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<232, 16>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x16_F16F16F16_RS = SM90::GMMA::MMA_64x232x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<232, 16>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using 
SM90_64x240x16_F16F16F16_SS = SM90::GMMA::MMA_64x240x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F16F16F16_RS = SM90::GMMA::MMA_64x240x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x16_F16F16F16_SS = SM90::GMMA::MMA_64x248x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = 
GMMA::ABLayout<248, 16>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x16_F16F16F16_RS = SM90::GMMA::MMA_64x248x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<248, 16>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x16_F32F16F16_SS = SM90::GMMA::MMA_64x24x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 24, 16>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x16_F32F16F16_RS = 
SM90::GMMA::MMA_64x24x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 24, 16>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x16_F32F16F16_SS = SM90::GMMA::MMA_64x40x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 40, 16>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x16_F32F16F16_RS = SM90::GMMA::MMA_64x40x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 40, 16>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x16_F32F16F16_SS = SM90::GMMA::MMA_64x48x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 48, 16>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x16_F32F16F16_RS = SM90::GMMA::MMA_64x48x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 48, 16>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x16_F32F16F16_SS = SM90::GMMA::MMA_64x56x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + 
using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 56, 16>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x16_F32F16F16_RS = SM90::GMMA::MMA_64x56x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 56, 16>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x16_F32F16F16_SS = SM90::GMMA::MMA_64x72x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 72, 16>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x16_F32F16F16_RS = SM90::GMMA::MMA_64x72x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 72, 16>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F32F16F16_SS = SM90::GMMA::MMA_64x80x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 80, 16>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F32F16F16_RS = SM90::GMMA::MMA_64x80x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + 
using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 80, 16>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x16_F32F16F16_SS = SM90::GMMA::MMA_64x88x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 88, 16>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x16_F32F16F16_RS = SM90::GMMA::MMA_64x88x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 88, 16>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major 
tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x16_F32F16F16_SS = SM90::GMMA::MMA_64x104x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<104, 16>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x16_F32F16F16_RS = SM90::GMMA::MMA_64x104x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<104, 16>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F32F16F16_SS = SM90::GMMA::MMA_64x112x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = 
Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F32F16F16_RS = SM90::GMMA::MMA_64x112x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x16_F32F16F16_SS = SM90::GMMA::MMA_64x120x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<120, 16>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x16_F32F16F16_RS = SM90::GMMA::MMA_64x120x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<120, 16>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x16_F32F16F16_SS = SM90::GMMA::MMA_64x136x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<136, 16>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x16_F32F16F16_RS = SM90::GMMA::MMA_64x136x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + 
using BLayout = GMMA::ABLayout<136, 16>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F32F16F16_SS = SM90::GMMA::MMA_64x144x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F32F16F16_RS = SM90::GMMA::MMA_64x144x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x16_F32F16F16_SS = 
SM90::GMMA::MMA_64x152x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<152, 16>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x16_F32F16F16_RS = SM90::GMMA::MMA_64x152x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<152, 16>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F32F16F16_SS = SM90::GMMA::MMA_64x160x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = 
GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F32F16F16_RS = SM90::GMMA::MMA_64x160x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x16_F32F16F16_SS = SM90::GMMA::MMA_64x168x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<168, 16>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x16_F32F16F16_RS = SM90::GMMA::MMA_64x168x16_F32F16F16_RS; + +template +struct 
MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<168, 16>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F32F16F16_SS = SM90::GMMA::MMA_64x176x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F32F16F16_RS = SM90::GMMA::MMA_64x176x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x16_F32F16F16_SS = SM90::GMMA::MMA_64x184x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<184, 16>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x16_F32F16F16_RS = SM90::GMMA::MMA_64x184x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<184, 16>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x16_F32F16F16_SS = SM90::GMMA::MMA_64x200x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = 
half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<200, 16>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x16_F32F16F16_RS = SM90::GMMA::MMA_64x200x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<200, 16>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F32F16F16_SS = SM90::GMMA::MMA_64x208x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F32F16F16_RS = SM90::GMMA::MMA_64x208x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x16_F32F16F16_SS = SM90::GMMA::MMA_64x216x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<216, 16>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x16_F32F16F16_RS = SM90::GMMA::MMA_64x216x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = 
half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<216, 16>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F32F16F16_SS = SM90::GMMA::MMA_64x224x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F32F16F16_RS = SM90::GMMA::MMA_64x224x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
+template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x16_F32F16F16_SS = SM90::GMMA::MMA_64x232x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<232, 16>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x16_F32F16F16_RS = SM90::GMMA::MMA_64x232x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<232, 16>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F32F16F16_SS = SM90::GMMA::MMA_64x240x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; 
+ + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F32F16F16_RS = SM90::GMMA::MMA_64x240x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x16_F32F16F16_SS = SM90::GMMA::MMA_64x248x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<248, 16>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn 
scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x16_F32F16F16_RS = SM90::GMMA::MMA_64x248x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<248, 16>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x24x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 24, 16>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x24x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_16>; + using ThrID = Layout<_128>; + using ALayout 
= GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 24, 16>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x40x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 40, 16>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x40x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 40, 16>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using 
SM90_64x48x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x48x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 48, 16>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x48x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 48, 16>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x56x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using 
BLayout = GMMA::ABLayout< 56, 16>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x56x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 56, 16>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x72x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 72, 16>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x16_F32BF16BF16_RS = 
SM90::GMMA::MMA_64x72x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 72, 16>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x80x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 80, 16>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x80x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 80, 16>; + using CLayout = GMMA::CLayout_64x80; 
+ + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x88x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 88, 16>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x88x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 88, 16>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x104x16_F32BF16BF16_SS; + +template +struct 
MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<104, 16>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x104x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<104, 16>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x112x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + 
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x112x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x120x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<120, 16>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x120x16_F32BF16BF16_RS; + +template +struct 
MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<120, 16>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x136x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<136, 16>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x136x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<136, 16>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x144x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x144x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x152x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = 
float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<152, 16>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x152x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<152, 16>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x160x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x160x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x168x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<168, 16>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x168x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = 
float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<168, 16>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x176x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x176x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x184x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<184, 16>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x184x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<184, 16>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x200x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = 
bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<200, 16>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x200x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<200, 16>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x208x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x208x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x216x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<216, 16>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x216x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = 
bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<216, 16>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x224x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x224x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x232x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<232, 16>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x232x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<232, 16>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x240x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = 
bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x240x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x248x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<248, 16>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x248x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<248, 16>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x24x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 24, 8>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x24x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB 
= GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 24, 8>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x40x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 40, 8>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x40x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 40, 8>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using 
SM90_64x48x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x48x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 48, 8>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x48x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 48, 8>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x56x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 56, 8>; + using CLayout = GMMA::CLayout_64x56; + + 
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x56x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 56, 8>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x72x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 72, 8>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x72x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + 
using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 72, 8>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x80x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 80, 8>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x80x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 80, 8>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using 
SM90_64x88x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x88x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 88, 8>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x88x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 88, 8>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x104x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<104, 8>; + using CLayout = GMMA::CLayout_64x104; + + 
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x104x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<104, 8>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x112x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<112, 8>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x112x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = 
float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<112, 8>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x120x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<120, 8>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x120x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<120, 8>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One +> +using SM90_64x136x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x136x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<136, 8>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x136x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<136, 8>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x144x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<144, 8>; + using 
CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x144x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<144, 8>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x152x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<152, 8>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x152x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = 
tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<152, 8>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x160x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<160, 8>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x160x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<160, 8>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x168x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<168, 8>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x168x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<168, 8>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x176x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = 
GMMA::ABLayout<176, 8>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x176x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<176, 8>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x184x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<184, 8>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x184x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA 
= tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<184, 8>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x200x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<200, 8>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x200x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<200, 8>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn 
scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x208x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<208, 8>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x208x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<208, 8>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x216x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 
8>; + using BLayout = GMMA::ABLayout<216, 8>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x216x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<216, 8>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x224x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<224, 8>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x224x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = 
float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<224, 8>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x232x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<232, 8>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x232x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<232, 8>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
+template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x240x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<240, 8>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x240x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<240, 8>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x248x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_8>; + using ThrID = Layout<_128>; + using 
ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<248, 8>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x248x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<248, 8>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x24x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x24x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + 
using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x48x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x80x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + 
+ using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x112x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using 
FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x144x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x160x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = 
GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x176x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC 
= int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x208x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x224x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + 
using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x240x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = 
int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x24x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x24x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x48x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using 
FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x80x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; 
+ using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x112x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x144x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = 
GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x160x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x176x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x208x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + 
+using SM90_64x208x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x224x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x240x32_S32S8S8_RS_TN; + 
+template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x24x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x24x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = 
int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x48x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x80x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using 
ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x112x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32S8U8_SS_TN_SATURATE; + +template <> 
+struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x144x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32S8U8_SS_TN = 
SM90::GMMA::MMA_64x160x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x176x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using 
SM90_64x176x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x208x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x224x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x240x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x24x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x24x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x48x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x80x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using 
SM90_64x80x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x112x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x144x32_S32S8U8_RS_TN; + +template <> 
+struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x160x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using 
ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x176x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x208x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using 
Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x224x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = 
GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x240x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x24x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + 
using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x24x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x48x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = 
GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x80x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x112x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using 
BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x144x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout 
= GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x160x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x176x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; 
+ using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x208x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = 
Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x224x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x240x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + 
using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x24x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x24x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using 
ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x48x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x80x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using 
CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x112x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; 
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+using SM90_64x144x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x144x32_S32U8S8_RS_TN;
+
+// Traits for the 64x144x32 GMMA computing S32 += U8 * S8, with A sourced from
+// registers (RS) and B from shared memory (TN layout).
+// NOTE(review): the `MMA_Traits<...>` specialization arguments and the
+// `smem_desc<...>` major below were restored from the surrounding generated
+// pattern after extraction stripped the tag-like `<...>` spans -- confirm
+// against the generator output.
+template <>
+struct MMA_Traits<SM90_64x144x32_S32U8S8_RS_TN>
+{
+  using ValTypeD = int32_t;
+  using ValTypeA = uint8_t;
+  using ValTypeB = int8_t;
+  using ValTypeC = int32_t;
+
+  // B operand is a shared-memory descriptor; TN GMMA reads B K-major.
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_144,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<144, 32>;
+  using CLayout = GMMA::CLayout_64x144;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+using SM90_64x144x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32U8S8_RS_TN_SATURATE;
+
+// Saturating variant of the 64x144x32 S32 += U8 * S8 RS GMMA.
+template <>
+struct MMA_Traits<SM90_64x144x32_S32U8S8_RS_TN_SATURATE>
+{
+  using ValTypeD = int32_t;
+  using ValTypeA = uint8_t;
+  using ValTypeB = int8_t;
+  using ValTypeC = int32_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_144,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<144, 32>;
+  using CLayout = GMMA::CLayout_64x144;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+using SM90_64x160x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x160x32_S32U8S8_RS_TN;
+
+// Traits for the 64x160x32 GMMA computing S32 += U8 * S8 (RS, TN).
+template <>
+struct MMA_Traits<SM90_64x160x32_S32U8S8_RS_TN>
+{
+  using ValTypeD = int32_t;
+  using ValTypeA = uint8_t;
+  using ValTypeB = int8_t;
+  using ValTypeC = int32_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_160,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<160, 32>;
+  using CLayout = GMMA::CLayout_64x160;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+using SM90_64x160x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32U8S8_RS_TN_SATURATE;
+
+// Saturating 64x160x32 GMMA computing S32 += U8 * S8, A from registers (RS),
+// B from shared memory (TN layout).
+// NOTE(review): the `MMA_Traits<...>` specialization arguments and the
+// `smem_desc<...>` major were restored from the surrounding generated pattern
+// after extraction stripped the tag-like `<...>` spans -- confirm against the
+// generator output.
+template <>
+struct MMA_Traits<SM90_64x160x32_S32U8S8_RS_TN_SATURATE>
+{
+  using ValTypeD = int32_t;
+  using ValTypeA = uint8_t;
+  using ValTypeB = int8_t;
+  using ValTypeC = int32_t;
+
+  // B operand is a shared-memory descriptor; TN GMMA reads B K-major.
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_160,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<160, 32>;
+  using CLayout = GMMA::CLayout_64x160;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+using SM90_64x176x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x176x32_S32U8S8_RS_TN;
+
+// Traits for the 64x176x32 GMMA computing S32 += U8 * S8 (RS, TN).
+template <>
+struct MMA_Traits<SM90_64x176x32_S32U8S8_RS_TN>
+{
+  using ValTypeD = int32_t;
+  using ValTypeA = uint8_t;
+  using ValTypeB = int8_t;
+  using ValTypeC = int32_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_176,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<176, 32>;
+  using CLayout = GMMA::CLayout_64x176;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+using SM90_64x176x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32U8S8_RS_TN_SATURATE;
+
+// Saturating variant of the 64x176x32 S32 += U8 * S8 RS GMMA.
+template <>
+struct MMA_Traits<SM90_64x176x32_S32U8S8_RS_TN_SATURATE>
+{
+  using ValTypeD = int32_t;
+  using ValTypeA = uint8_t;
+  using ValTypeB = int8_t;
+  using ValTypeC = int32_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_176,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<176, 32>;
+  using CLayout = GMMA::CLayout_64x176;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+using SM90_64x208x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x208x32_S32U8S8_RS_TN;
+
+template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x224x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = 
uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x240x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x24x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + 
using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x24x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x48x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using 
FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x80x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x112x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; 
+ + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x144x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = 
uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x160x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x176x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; 
+ using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x208x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = 
int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x224x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x240x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ 
+ using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x24x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x24x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using 
ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x48x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x80x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + 
using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x112x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using 
ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x144x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x160x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 
32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x176x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x208x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x224x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x240x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x24x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x24x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x24x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = 
GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x24x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x40x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x40x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x40x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x40x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using 
CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x48x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x48x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x48x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + 
using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x48x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x56x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x56x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x56x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x56x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; 
+ + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x72x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x72x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using 
SM90_64x72x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x72x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x72x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x80x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = 
GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x80x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x80x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x80x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using 
ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x88x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x88x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA 
= GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x88x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x88x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x104x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = 
GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x104x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x104x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x104x32_F32E4M3E4M3_RS_TN; + +template 
+struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x112x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x112x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x112x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x112x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x120x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = 
GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x120x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x120x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x120x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x136x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x136x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<136, 
32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x136x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x136x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x144x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using 
ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x144x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x144x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x144x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x152x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x152x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x152x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x152x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One 
+> +using SM90_64x160x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x160x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x160x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x160x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 
32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x160x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x168x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x168x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using 
ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x168x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x168x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x176x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x176x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x176x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = 
GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x176x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x184x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x184x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x184x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x184x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<184, 
32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x200x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x200x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x200x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using 
ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x200x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x208x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x208x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x208x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x208x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x216x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x216x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One +> +using SM90_64x216x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x216x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x216x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x224x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = 
GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x224x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x224x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x224x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = 
float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x232x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x232x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x232x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x232x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x240x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = 
GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x240x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x240x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x240x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x248x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x248x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<248, 
32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x248x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x248x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x24x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA 
= float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x24x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x24x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x24x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x40x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x40x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x40x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x40x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using 
SM90_64x48x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x48x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x48x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x48x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = 
GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x48x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x56x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x56x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using 
ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x56x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x56x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x72x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x72x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x72x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = 
GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x72x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x80x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x80x32_F16E4M3E5M2_RS_TN; + +template +struct 
MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x80x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x80x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x88x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x88x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x88x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = 
GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x88x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x104x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x104x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x104x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x104x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<104, 
32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x112x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x112x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x112x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using 
ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x112x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x120x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x120x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x120x32_F16E4M3E5M2_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x120x32_F16E4M3E5M2_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e4m3_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_120,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<120, 32>;
+  using CLayout = GMMA::CLayout_64x120;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x120x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x120x32_F32E4M3E5M2_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x120x32_F32E4M3E5M2_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e4m3_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_120,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout<120, 32>;
+  using CLayout = GMMA::CLayout_64x120;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x120x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x120x32_F32E4M3E5M2_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x120x32_F32E4M3E5M2_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e4m3_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_120,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<120, 32>;
+  using CLayout = GMMA::CLayout_64x120;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x136x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x136x32_F16E4M3E5M2_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x136x32_F16E4M3E5M2_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e4m3_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_136,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout<136, 32>;
+  using CLayout = GMMA::CLayout_64x136;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x136x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x136x32_F16E4M3E5M2_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x136x32_F16E4M3E5M2_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e4m3_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_136,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<136, 32>;
+  using CLayout = GMMA::CLayout_64x136;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x136x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x136x32_F32E4M3E5M2_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x136x32_F32E4M3E5M2_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e4m3_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_136,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout<136, 32>;
+  using CLayout = GMMA::CLayout_64x136;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x136x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x136x32_F32E4M3E5M2_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x136x32_F32E4M3E5M2_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e4m3_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_136,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<136, 32>;
+  using CLayout = GMMA::CLayout_64x136;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x144x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x144x32_F16E4M3E5M2_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x144x32_F16E4M3E5M2_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e4m3_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_144,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout<144, 32>;
+  using CLayout = GMMA::CLayout_64x144;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x144x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x144x32_F16E4M3E5M2_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x144x32_F16E4M3E5M2_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e4m3_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_144,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<144, 32>;
+  using CLayout = GMMA::CLayout_64x144;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x144x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x144x32_F32E4M3E5M2_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x144x32_F32E4M3E5M2_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e4m3_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_144,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout<144, 32>;
+  using CLayout = GMMA::CLayout_64x144;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x144x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x144x32_F32E4M3E5M2_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x144x32_F32E4M3E5M2_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e4m3_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_144,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<144, 32>;
+  using CLayout = GMMA::CLayout_64x144;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x152x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x152x32_F16E4M3E5M2_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x152x32_F16E4M3E5M2_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e4m3_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_152,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout<152, 32>;
+  using CLayout = GMMA::CLayout_64x152;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x152x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x152x32_F16E4M3E5M2_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x152x32_F16E4M3E5M2_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e4m3_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_152,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<152, 32>;
+  using CLayout = GMMA::CLayout_64x152;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x152x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x152x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x160x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = 
GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x160x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x160x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x160x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x168x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x168x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<168, 
32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x168x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x168x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x176x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using 
ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x176x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x176x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x176x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x184x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x184x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x184x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x184x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One 
+> +using SM90_64x200x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x200x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x200x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x200x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<200, 
32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x200x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x208x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x208x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using 
ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x208x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x208x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x216x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x216x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x216x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = 
GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x216x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x224x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x224x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x224x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x224x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 
32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x232x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x232x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x232x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using 
ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x232x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x240x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x240x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x240x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x240x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x248x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x248x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One +> +using SM90_64x248x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x248x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x248x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x24x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = 
GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x24x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x24x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x24x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + 
using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x40x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x40x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x40x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x40x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x48x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; 
+ using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x48x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x48x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x48x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x56x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x56x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = 
GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x56x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x56x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x72x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB 
= float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x72x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x72x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
+template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x72x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x80x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x80x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = 
GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x80x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x80x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x88x32_F16E5M2E4M3_SS_TN; + +template +struct 
MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x88x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x88x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x88x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x104x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x104x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x104x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x104x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One 
+> +using SM90_64x112x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x112x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x112x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x112x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 
32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x112x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x120x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x120x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using 
ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x120x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x120x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x136x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x136x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x136x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = 
GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x136x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x144x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x144x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x144x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x144x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 
32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x152x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x152x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x152x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using 
ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x152x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x160x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x160x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x160x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x160x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x168x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x168x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One
+>
+using SM90_64x168x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x168x32_F32E5M2E4M3_SS_TN<scaleA, scaleB>;
+
+// MMA_Traits for SM90 GMMA FP8 (e5m2 x e4m3) atoms, N = 168..184.
+// NOTE(review): the <scaleA, scaleB> instantiations and <GMMA::Major::K> descriptor
+// arguments below were lost in extraction and have been reconstructed from the
+// surrounding pattern — confirm against upstream mma_traits_sm90_gmma_ext.hpp.
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x168x32_F32E5M2E4M3_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_168,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout<168, 32>;
+  using CLayout = GMMA::CLayout_64x168;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x168x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x168x32_F32E5M2E4M3_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x168x32_F32E5M2E4M3_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_168,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<168, 32>;
+  using CLayout = GMMA::CLayout_64x168;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x176x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x176x32_F16E5M2E4M3_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x176x32_F16E5M2E4M3_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_176,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout<176, 32>;
+  using CLayout = GMMA::CLayout_64x176;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x176x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x176x32_F16E5M2E4M3_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x176x32_F16E5M2E4M3_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_176,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<176, 32>;
+  using CLayout = GMMA::CLayout_64x176;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x176x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x176x32_F32E5M2E4M3_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x176x32_F32E5M2E4M3_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_176,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout<176, 32>;
+  using CLayout = GMMA::CLayout_64x176;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x176x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x176x32_F32E5M2E4M3_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x176x32_F32E5M2E4M3_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_176,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<176, 32>;
+  using CLayout = GMMA::CLayout_64x176;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x184x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x184x32_F16E5M2E4M3_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x184x32_F16E5M2E4M3_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_184,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout<184, 32>;
+  using CLayout = GMMA::CLayout_64x184;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x184x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x184x32_F16E5M2E4M3_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x184x32_F16E5M2E4M3_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_184,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<184, 32>;
+  using CLayout = GMMA::CLayout_64x184;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// MMA_Traits for SM90 GMMA FP8 (e5m2 x e4m3) atoms, N = 184..216.
+// NOTE(review): the <scaleA, scaleB> instantiations and <GMMA::Major::K> descriptor
+// arguments below were lost in extraction and have been reconstructed from the
+// surrounding pattern — confirm against upstream mma_traits_sm90_gmma_ext.hpp.
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x184x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x184x32_F32E5M2E4M3_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x184x32_F32E5M2E4M3_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_184,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout<184, 32>;
+  using CLayout = GMMA::CLayout_64x184;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x184x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x184x32_F32E5M2E4M3_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x184x32_F32E5M2E4M3_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_184,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<184, 32>;
+  using CLayout = GMMA::CLayout_64x184;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x200x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x200x32_F16E5M2E4M3_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x200x32_F16E5M2E4M3_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_200,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout<200, 32>;
+  using CLayout = GMMA::CLayout_64x200;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x200x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x200x32_F16E5M2E4M3_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x200x32_F16E5M2E4M3_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_200,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<200, 32>;
+  using CLayout = GMMA::CLayout_64x200;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x200x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x200x32_F32E5M2E4M3_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x200x32_F32E5M2E4M3_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_200,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout<200, 32>;
+  using CLayout = GMMA::CLayout_64x200;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x200x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x200x32_F32E5M2E4M3_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x200x32_F32E5M2E4M3_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_200,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<200, 32>;
+  using CLayout = GMMA::CLayout_64x200;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x208x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x208x32_F16E5M2E4M3_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x208x32_F16E5M2E4M3_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_208,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout<208, 32>;
+  using CLayout = GMMA::CLayout_64x208;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x208x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x208x32_F16E5M2E4M3_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x208x32_F16E5M2E4M3_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_208,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<208, 32>;
+  using CLayout = GMMA::CLayout_64x208;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x208x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x208x32_F32E5M2E4M3_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x208x32_F32E5M2E4M3_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_208,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout<208, 32>;
+  using CLayout = GMMA::CLayout_64x208;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x208x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x208x32_F32E5M2E4M3_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x208x32_F32E5M2E4M3_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_208,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<208, 32>;
+  using CLayout = GMMA::CLayout_64x208;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x216x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x216x32_F16E5M2E4M3_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x216x32_F16E5M2E4M3_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_216,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout<216, 32>;
+  using CLayout = GMMA::CLayout_64x216;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x216x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x216x32_F16E5M2E4M3_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x216x32_F16E5M2E4M3_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_216,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<216, 32>;
+  using CLayout = GMMA::CLayout_64x216;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x216x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x216x32_F32E5M2E4M3_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x216x32_F32E5M2E4M3_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_216,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout<216, 32>;
+  using CLayout = GMMA::CLayout_64x216;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// MMA_Traits for SM90 GMMA FP8 (e5m2 x e4m3) atoms, N = 216..240.
+// NOTE(review): the <scaleA, scaleB> instantiations and <GMMA::Major::K> descriptor
+// arguments below were lost in extraction and have been reconstructed from the
+// surrounding pattern — confirm against upstream mma_traits_sm90_gmma_ext.hpp.
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x216x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x216x32_F32E5M2E4M3_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x216x32_F32E5M2E4M3_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_216,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<216, 32>;
+  using CLayout = GMMA::CLayout_64x216;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x224x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x224x32_F16E5M2E4M3_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x224x32_F16E5M2E4M3_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_224,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout<224, 32>;
+  using CLayout = GMMA::CLayout_64x224;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x224x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x224x32_F16E5M2E4M3_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x224x32_F16E5M2E4M3_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_224,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<224, 32>;
+  using CLayout = GMMA::CLayout_64x224;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x224x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x224x32_F32E5M2E4M3_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x224x32_F32E5M2E4M3_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_224,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout<224, 32>;
+  using CLayout = GMMA::CLayout_64x224;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x224x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x224x32_F32E5M2E4M3_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x224x32_F32E5M2E4M3_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_224,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<224, 32>;
+  using CLayout = GMMA::CLayout_64x224;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x232x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x232x32_F16E5M2E4M3_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x232x32_F16E5M2E4M3_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_232,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout<232, 32>;
+  using CLayout = GMMA::CLayout_64x232;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x232x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x232x32_F16E5M2E4M3_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x232x32_F16E5M2E4M3_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_232,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<232, 32>;
+  using CLayout = GMMA::CLayout_64x232;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x232x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x232x32_F32E5M2E4M3_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x232x32_F32E5M2E4M3_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_232,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout<232, 32>;
+  using CLayout = GMMA::CLayout_64x232;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x232x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x232x32_F32E5M2E4M3_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x232x32_F32E5M2E4M3_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_232,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<232, 32>;
+  using CLayout = GMMA::CLayout_64x232;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x240x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x240x32_F16E5M2E4M3_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x240x32_F16E5M2E4M3_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_240,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout<240, 32>;
+  using CLayout = GMMA::CLayout_64x240;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x240x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x240x32_F16E5M2E4M3_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x240x32_F16E5M2E4M3_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_240,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<240, 32>;
+  using CLayout = GMMA::CLayout_64x240;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x240x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x240x32_F32E5M2E4M3_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x240x32_F32E5M2E4M3_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_240,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout<240, 32>;
+  using CLayout = GMMA::CLayout_64x240;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x240x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x240x32_F32E5M2E4M3_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x240x32_F32E5M2E4M3_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_240,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<240, 32>;
+  using CLayout = GMMA::CLayout_64x240;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// MMA_Traits for SM90 GMMA FP8 atoms: N = 248 (e5m2 x e4m3), then N = 24..48 (e5m2 x e5m2).
+// NOTE(review): the <scaleA, scaleB> instantiations and <GMMA::Major::K> descriptor
+// arguments below were lost in extraction and have been reconstructed from the
+// surrounding pattern — confirm against upstream mma_traits_sm90_gmma_ext.hpp.
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x248x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x248x32_F16E5M2E4M3_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x248x32_F16E5M2E4M3_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_248,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout<248, 32>;
+  using CLayout = GMMA::CLayout_64x248;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x248x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x248x32_F16E5M2E4M3_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x248x32_F16E5M2E4M3_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_248,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<248, 32>;
+  using CLayout = GMMA::CLayout_64x248;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x248x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x248x32_F32E5M2E4M3_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x248x32_F32E5M2E4M3_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_248,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout<248, 32>;
+  using CLayout = GMMA::CLayout_64x248;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x248x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x248x32_F32E5M2E4M3_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x248x32_F32E5M2E4M3_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_248,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<248, 32>;
+  using CLayout = GMMA::CLayout_64x248;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x24x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x24x32_F16E5M2E5M2_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x24x32_F16E5M2E5M2_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_24,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout< 24, 32>;
+  using CLayout = GMMA::CLayout_64x24;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x24x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x24x32_F16E5M2E5M2_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x24x32_F16E5M2E5M2_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_24,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout< 24, 32>;
+  using CLayout = GMMA::CLayout_64x24;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x24x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x24x32_F32E5M2E5M2_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x24x32_F32E5M2E5M2_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_24,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout< 24, 32>;
+  using CLayout = GMMA::CLayout_64x24;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x24x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x24x32_F32E5M2E5M2_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x24x32_F32E5M2E5M2_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_24,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout< 24, 32>;
+  using CLayout = GMMA::CLayout_64x24;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x40x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x40x32_F16E5M2E5M2_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x40x32_F16E5M2E5M2_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_40,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout< 40, 32>;
+  using CLayout = GMMA::CLayout_64x40;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x40x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x40x32_F16E5M2E5M2_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x40x32_F16E5M2E5M2_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_40,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout< 40, 32>;
+  using CLayout = GMMA::CLayout_64x40;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x40x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x40x32_F32E5M2E5M2_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x40x32_F32E5M2E5M2_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_40,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout< 40, 32>;
+  using CLayout = GMMA::CLayout_64x40;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x40x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x40x32_F32E5M2E5M2_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x40x32_F32E5M2E5M2_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_40,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout< 40, 32>;
+  using CLayout = GMMA::CLayout_64x40;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x48x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x48x32_F16E5M2E5M2_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x48x32_F16E5M2E5M2_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_48,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout< 48, 32>;
+  using CLayout = GMMA::CLayout_64x48;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x48x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x48x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x48x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; 
+ + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x56x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x56x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using 
SM90_64x56x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x56x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x56x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x72x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = 
GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x72x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x72x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x72x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using 
ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x80x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x80x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA 
= GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x80x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x80x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x88x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = 
GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x88x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x88x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x88x32_F32E5M2E5M2_RS_TN; + +template +struct 
MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x104x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x104x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x104x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x104x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x112x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = 
GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x112x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x112x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x112x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x120x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x120x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<120, 
32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x120x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x120x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x136x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using 
ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x136x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x136x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x136x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x144x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x144x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x144x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x144x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One 
+> +using SM90_64x152x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x152x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x152x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x152x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<152, 
32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x152x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x160x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x160x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using 
ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x160x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x160x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x168x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x168x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x168x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = 
GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x168x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x176x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x176x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x176x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x176x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 
32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x184x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x184x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x184x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using 
ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x184x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x200x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x200x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x200x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x200x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x208x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x208x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x208x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x208x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x216x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = 
GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x216x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x216x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x216x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = 
float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x224x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x224x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x224x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x224x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x232x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = 
GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x232x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x232x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x232x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x240x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x240x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 
32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x240x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x240x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x248x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using 
ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x248x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x248x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x248x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp b/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp index 7252a0ef5..27c41ad33 100644 --- a/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp +++ b/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp @@ -352,57 +352,6 @@ struct MMA_Traits -struct MMA_Traits> -{ - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, half_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; - using ValTypeC = half_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, half_t>; - 
using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; - using ValTypeC = half_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - template struct MMA_Traits> { @@ -450,9 +399,8 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, half_t>; @@ -463,22 +411,20 @@ struct MMA_Traits; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_32>; + using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, half_t>; @@ -488,21 +434,20 @@ struct MMA_Traits; - using Shape_MNK = Shape<_64,_80,_32>; + using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif 
//////////////////////////////////////////////////////////////////////////////////////////////////// template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, half_t>; @@ -513,12 +458,12 @@ struct MMA_Traits; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_96,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout< 96, 32>; - using CLayout = GMMA::CLayout_64x96; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -526,7 +471,7 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, half_t>; @@ -536,21 +481,20 @@ struct MMA_Traits; - using Shape_MNK = Shape<_64,_96,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout< 96, 32>; - using CLayout = GMMA::CLayout_64x96; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, half_t>; @@ -561,22 +505,20 @@ struct MMA_Traits; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut 
accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, half_t>; @@ -586,21 +528,20 @@ struct MMA_Traits; - using Shape_MNK = Shape<_64,_112,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, half_t>; @@ -611,12 +552,12 @@ struct MMA_Traits; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_32>; + using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -624,7 +565,7 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, half_t>; @@ -634,189 +575,177 @@ struct MMA_Traits; - using Shape_MNK = Shape<_64,_128,_32>; + using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = 
GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; + using ValTypeD = float; using ValTypeA = sparse_elem<2, half_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = half_t; - using ValTypeC = half_t; + using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_32>; + using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; + using ValTypeD = float; using ValTypeA = sparse_elem<2, half_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = half_t; - using ValTypeC = half_t; + using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_32>; + using Shape_MNK = Shape<_64,_8,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct 
MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; + using ValTypeD = float; using ValTypeA = sparse_elem<2, half_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = half_t; - using ValTypeC = half_t; + using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; + using ValTypeD = float; using ValTypeA = sparse_elem<2, half_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = half_t; - using ValTypeC = half_t; + using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_32>; + using Shape_MNK = Shape<_64,_16,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; + using ValTypeD = float; using ValTypeA = sparse_elem<2, half_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = half_t; - using ValTypeC = half_t; + using ValTypeC 
= float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; + using ValTypeD = float; using ValTypeA = sparse_elem<2, half_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = half_t; - using ValTypeC = half_t; + using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_32>; + using Shape_MNK = Shape<_64,_32,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// template -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; + using ValTypeD = float; using ValTypeA = sparse_elem<2, half_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = half_t; - using ValTypeC = half_t; + using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_192,_32>; + using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; using ELayout = GMMA::ELayout_64x32; - using BLayout = 
GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -824,189 +753,177 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; + using ValTypeD = float; using ValTypeA = sparse_elem<2, half_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = half_t; - using ValTypeC = half_t; + using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_192,_32>; + using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; + using ValTypeD = float; using ValTypeA = sparse_elem<2, half_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = half_t; - using ValTypeC = half_t; + using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_208,_32>; + using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template 
-struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; + using ValTypeD = float; using ValTypeA = sparse_elem<2, half_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = half_t; - using ValTypeC = half_t; + using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_208,_32>; + using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; + using ValTypeD = float; using ValTypeA = sparse_elem<2, half_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = half_t; - using ValTypeC = half_t; + using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_224,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; + using ValTypeD = float; using ValTypeA = sparse_elem<2, half_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = half_t; - using ValTypeC = half_t; + using 
ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_224,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; + using ValTypeD = float; using ValTypeA = sparse_elem<2, half_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = half_t; - using ValTypeC = half_t; + using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_240,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; + using ValTypeD = float; using ValTypeA = sparse_elem<2, half_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = half_t; - using ValTypeC = half_t; + using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_240,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; using ELayout = GMMA::ELayout_64x32; 
- using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// template -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; + using ValTypeD = float; using ValTypeA = sparse_elem<2, half_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = half_t; - using ValTypeC = half_t; + using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; @@ -1024,13 +941,13 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; + using ValTypeD = float; using ValTypeA = sparse_elem<2, half_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = half_t; - using ValTypeC = half_t; + using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; @@ -1047,12 +964,12 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, half_t>; + using ValTypeA = sparse_elem<2, bfloat16_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; + using ValTypeB = bfloat16_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; @@ -1071,12 +988,12 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, half_t>; + using ValTypeA = sparse_elem<2, bfloat16_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; + using ValTypeB = bfloat16_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; @@ -1094,12 +1011,12 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, half_t>; + using ValTypeA = sparse_elem<2, bfloat16_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; + using ValTypeB = bfloat16_t; 
using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; @@ -1118,12 +1035,12 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, half_t>; + using ValTypeA = sparse_elem<2, bfloat16_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; + using ValTypeB = bfloat16_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; @@ -1141,12 +1058,12 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, half_t>; + using ValTypeA = sparse_elem<2, bfloat16_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; + using ValTypeB = bfloat16_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; @@ -1165,12 +1082,12 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, half_t>; + using ValTypeA = sparse_elem<2, bfloat16_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; + using ValTypeB = bfloat16_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; @@ -1187,75 +1104,71 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, half_t>; + using ValTypeA = sparse_elem<2, bfloat16_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; + using ValTypeB = bfloat16_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_48,_32>; + using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif 
//////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, half_t>; + using ValTypeA = sparse_elem<2, bfloat16_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; + using ValTypeB = bfloat16_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_48,_32>; + using Shape_MNK = Shape<_64,_64,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, half_t>; + using ValTypeA = sparse_elem<2, bfloat16_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; + using ValTypeB = bfloat16_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_64,_32>; + using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout< 64, 32>; - using CLayout = GMMA::CLayout_64x64; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -1263,97 +1176,93 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, half_t>; + using ValTypeA = sparse_elem<2, bfloat16_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using 
ValTypeB = half_t; + using ValTypeB = bfloat16_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_64,_32>; + using Shape_MNK = Shape<_64,_96,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout< 64, 32>; - using CLayout = GMMA::CLayout_64x64; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, half_t>; + using ValTypeA = sparse_elem<2, bfloat16_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; + using ValTypeB = bfloat16_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, half_t>; + using ValTypeA = sparse_elem<2, bfloat16_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; + using ValTypeB = bfloat16_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_32>; + using Shape_MNK = Shape<_64,_128,_32>; using ThrID 
= Layout<_128>; using ALayout = GMMA::ALayout_64x32; using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, half_t>; + using ValTypeA = sparse_elem<2, bfloat16_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; + using ValTypeB = bfloat16_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_96,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout< 96, 32>; - using CLayout = GMMA::CLayout_64x96; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -1361,8870 +1270,257 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, half_t>; + using ValTypeA = sparse_elem<2, bfloat16_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; + using ValTypeB = bfloat16_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_96,_32>; + using Shape_MNK = Shape<_64,_192,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout< 96, 32>; - using CLayout = GMMA::CLayout_64x96; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; 
//////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, half_t>; + using ValTypeA = sparse_elem<2, bfloat16_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; + using ValTypeB = bfloat16_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_32>; + using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 32>; using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, half_t>; + using ValTypeA = sparse_elem<2, bfloat16_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; + using ValTypeB = bfloat16_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_32>; + using Shape_MNK = Shape<_64,_256,_32>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x32; using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ValTypeD = float; 
- using ValTypeA = sparse_elem<2, half_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; using ValTypeC = float; - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_32>; + using Shape_MNK = Shape<_64,_8,_16>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, half_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; using ValTypeC = float; - using FrgTypeB = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_32>; + using Shape_MNK = Shape<_64,_8,_16>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, half_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; using ValTypeC = float; - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_32>; + using Shape_MNK = Shape<_64,_16,_16>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, half_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; using ValTypeC = float; - using FrgTypeB = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_32>; + using Shape_MNK = Shape<_64,_16,_16>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 16, 16>; + 
using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, half_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; using ValTypeC = float; - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_32>; + using Shape_MNK = Shape<_64,_32,_16>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, half_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; using ValTypeC = float; - using FrgTypeB = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_32>; + using Shape_MNK = Shape<_64,_32,_16>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using ELayout = GMMA::ELayout_64x32; - 
using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, half_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; using ValTypeC = float; - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, half_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct 
MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, half_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, half_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, half_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, half_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, half_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, half_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using ELayout = GMMA::ELayout_64x32; - using BLayout = 
GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, half_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, half_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, half_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_32>; - using ThrID = 
Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<256, 32>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, half_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = half_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<256, 32>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_32>; - using ThrID = Layout<_128>; - using ALayout = 
GMMA::ALayout_64x32; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout< 8, 32>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout< 16, 32>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_32>; - using ThrID = Layout<_128>; - using ALayout = 
GMMA::ABLayout< 64, 32>; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout< 32, 32>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using 
Shape_MNK = Shape<_64,_48,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout< 48, 32>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout< 64, 32>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout< 64, 32>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = 
GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout< 80, 32>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout< 96, 32>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - 
using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout< 96, 32>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<112, 32>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - 
using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<128, 32>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<144, 32>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using ELayout = GMMA::ELayout_64x32; - 
using BLayout = GMMA::ABLayout<160, 32>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<176, 32>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK 
= Shape<_64,_192,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<192, 32>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = 
bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<208, 32>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<224, 32>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct 
MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<240, 32>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 32>; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<256, 32>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, bfloat16_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = bfloat16_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_32>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x32; - using ELayout = GMMA::ELayout_64x32; - using BLayout = GMMA::ABLayout<256, 32>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using ELayout = GMMA::ELayout_64x16; - using BLayout = GMMA::ABLayout< 8, 16>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using ELayout = GMMA::ELayout_64x16; - using BLayout = GMMA::ABLayout< 8, 16>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using ELayout = GMMA::ELayout_64x16; - using BLayout = GMMA::ABLayout< 16, 16>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using ELayout = GMMA::ELayout_64x16; - using BLayout = GMMA::ABLayout< 16, 16>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using ELayout = GMMA::ELayout_64x16; - using BLayout = GMMA::ABLayout< 32, 16>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; 
-}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using ELayout = GMMA::ELayout_64x16; - using BLayout = GMMA::ABLayout< 32, 16>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using ELayout = GMMA::ELayout_64x16; - using BLayout = GMMA::ABLayout< 48, 16>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using ELayout = GMMA::ELayout_64x16; - using BLayout = GMMA::ABLayout< 48, 16>; - using CLayout = 
GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using ELayout = GMMA::ELayout_64x16; - using BLayout = GMMA::ABLayout< 64, 16>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using ELayout = GMMA::ELayout_64x16; - using BLayout = GMMA::ABLayout< 64, 16>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using ELayout = 
GMMA::ELayout_64x16; - using BLayout = GMMA::ABLayout< 80, 16>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using ELayout = GMMA::ELayout_64x16; - using BLayout = GMMA::ABLayout< 80, 16>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using ELayout = GMMA::ELayout_64x16; - using BLayout = GMMA::ABLayout< 96, 16>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_16>; - using ThrID = Layout<_128>; - using ALayout = 
GMMA::ALayout_64x16; - using ELayout = GMMA::ELayout_64x16; - using BLayout = GMMA::ABLayout< 96, 16>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using ELayout = GMMA::ELayout_64x16; - using BLayout = GMMA::ABLayout<112, 16>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using ELayout = GMMA::ELayout_64x16; - using BLayout = GMMA::ABLayout<112, 16>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - 
using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using ELayout = GMMA::ELayout_64x16; - using BLayout = GMMA::ABLayout<128, 16>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using ELayout = GMMA::ELayout_64x16; - using BLayout = GMMA::ABLayout<128, 16>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using ELayout = GMMA::ELayout_64x16; - using BLayout = GMMA::ABLayout<144, 16>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = 
sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using ELayout = GMMA::ELayout_64x16; - using BLayout = GMMA::ABLayout<144, 16>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using ELayout = GMMA::ELayout_64x16; - using BLayout = GMMA::ABLayout<160, 16>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using ELayout = GMMA::ELayout_64x16; - using BLayout = GMMA::ABLayout<160, 16>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using ELayout = GMMA::ELayout_64x16; - using BLayout = GMMA::ABLayout<176, 16>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using ELayout = GMMA::ELayout_64x16; - using BLayout = GMMA::ABLayout<176, 16>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using ELayout = GMMA::ELayout_64x16; - using BLayout = GMMA::ABLayout<192, 16>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using ELayout = GMMA::ELayout_64x16; - using BLayout = GMMA::ABLayout<192, 16>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using ELayout = GMMA::ELayout_64x16; - using BLayout = GMMA::ABLayout<208, 16>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using ELayout = GMMA::ELayout_64x16; - using BLayout = 
GMMA::ABLayout<208, 16>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using ELayout = GMMA::ELayout_64x16; - using BLayout = GMMA::ABLayout<224, 16>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using ELayout = GMMA::ELayout_64x16; - using BLayout = GMMA::ABLayout<224, 16>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = 
GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using ELayout = GMMA::ELayout_64x16; - using BLayout = GMMA::ABLayout<240, 16>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using ELayout = GMMA::ELayout_64x16; - using BLayout = GMMA::ABLayout<240, 16>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 16>; - using ELayout = GMMA::ELayout_64x16; - using BLayout = GMMA::ABLayout<256, 16>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, tfloat32_t>; - using ValTypeE = sparse_elem<4, uint8_t>; - using ValTypeB = tfloat32_t; - 
using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_16>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x16; - using ELayout = GMMA::ELayout_64x16; - using BLayout = GMMA::ABLayout<256, 16>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 8, 64>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 8, 64>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC 
= int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 16, 64>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 16, 64>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 32, 64>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using 
ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 32, 64>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 48, 64>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 48, 64>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - 
-template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 64, 64>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 64, 64>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 80, 64>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 80, 64>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 96, 64>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 96, 64>; - using 
CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<112, 64>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<112, 64>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_64>; - 
using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<128, 64>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<128, 64>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<144, 64>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = 
int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<144, 64>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<160, 64>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<160, 64>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - 
-#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<176, 64>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<176, 64>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<192, 64>; - using CLayout = GMMA::CLayout_64x192; - - 
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<192, 64>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<208, 64>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_64>; - using ThrID = Layout<_128>; - using ALayout = 
GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<208, 64>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<224, 64>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<224, 64>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - 
using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<240, 64>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<240, 64>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<256, 64>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using 
ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<256, 64>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 8, 64>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 8, 64>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = 
sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 16, 64>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 16, 64>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 32, 64>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = 
int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 32, 64>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 48, 64>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 48, 64>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = 
sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 64, 64>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 64, 64>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 80, 64>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, 
int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 80, 64>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 96, 64>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 96, 64>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = 
sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<112, 64>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<112, 64>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<128, 64>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, 
uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<128, 64>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<144, 64>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<144, 64>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using 
ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<160, 64>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<160, 64>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<176, 64>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<176, 64>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<192, 64>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<192, 64>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<208, 64>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<208, 64>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<224, 64>; - using 
CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<224, 64>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<240, 64>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_64>; - using ThrID = Layout<_128>; - using ALayout = 
GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<240, 64>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<256, 64>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<256, 64>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using 
ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 8, 64>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 8, 64>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 16, 64>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_64>; - using ThrID = Layout<_128>; - using ALayout = 
GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 16, 64>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 32, 64>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 32, 64>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - 
using Shape_MNK = Shape<_64,_48,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 48, 64>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 48, 64>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 64, 64>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - 
using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 64, 64>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 80, 64>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 80, 64>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct 
MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 96, 64>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 96, 64>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<112, 64>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<112, 64>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<128, 64>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<128, 64>; 
- using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<144, 64>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<144, 64>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = 
GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<160, 64>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<160, 64>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<176, 64>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using 
ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<176, 64>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<192, 64>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<192, 64>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<208, 64>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<208, 64>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_64>; - using ThrID = Layout<_128>; - using 
ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<224, 64>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<224, 64>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<240, 64>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, 
uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<240, 64>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<256, 64>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<256, 64>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = 
sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 8, 64>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 8, 64>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 16, 64>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = 
uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 16, 64>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 32, 64>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 32, 64>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = 
int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 48, 64>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 48, 64>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 64, 64>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - 
using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 64, 64>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 80, 64>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 80, 64>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = 
uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 96, 64>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 96, 64>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<112, 64>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, 
uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<112, 64>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<128, 64>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<128, 64>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using 
ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<144, 64>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<144, 64>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<160, 64>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = 
int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<160, 64>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<176, 64>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<176, 64>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - 
-template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<192, 64>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<192, 64>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<208, 64>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<208, 64>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<224, 64>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<224, 64>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<240, 64>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<240, 64>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<256, 64>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut 
accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, int8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<256, 64>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 8, 64>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 8, 64>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut 
accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 16, 64>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 16, 64>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 32, 64>; - using CLayout = 
GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 32, 64>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 48, 64>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_64>; - using ThrID = 
Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 48, 64>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 64, 64>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 64, 64>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; 
- using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 80, 64>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 80, 64>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 96, 64>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, 
uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 96, 64>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<112, 64>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<112, 64>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<128, 64>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<128, 64>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<144, 64>; - using 
CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<144, 64>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<160, 64>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = 
GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<160, 64>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<176, 64>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<176, 64>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = 
sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<192, 64>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<192, 64>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<208, 64>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<208, 64>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<224, 64>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_64>; - using ThrID = Layout<_128>; - using 
ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<224, 64>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<240, 64>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<240, 64>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using 
ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<256, 64>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<256, 64>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 8, 64>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using 
ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 8, 64>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 16, 64>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 16, 64>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = 
Shape<_64,_32,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 32, 64>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 32, 64>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 48, 64>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = 
GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 48, 64>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 64, 64>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 64, 64>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - 
using Shape_MNK = Shape<_64,_80,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 80, 64>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 80, 64>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 96, 64>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = 
Shape<_64,_96,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 96, 64>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<112, 64>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<112, 64>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = 
GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<128, 64>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<128, 64>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<144, 64>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC 
= int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<144, 64>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<160, 64>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<160, 64>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = 
sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<176, 64>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<176, 64>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<192, 64>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = 
sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<192, 64>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<208, 64>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<208, 64>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<224, 64>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<224, 64>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<240, 64>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<240, 64>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<256, 64>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = int8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<256, 64>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 8, 64>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 8, 64>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 16, 64>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut 
accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 16, 64>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 32, 64>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 32, 64>; - using CLayout = 
GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 48, 64>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 48, 64>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_64>; - using ThrID = 
Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 64, 64>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 64, 64>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 80, 64>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using 
ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 80, 64>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 96, 64>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 96, 64>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = 
sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<112, 64>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<112, 64>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<128, 64>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<128, 64>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<144, 64>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = 
GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<144, 64>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<160, 64>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<160, 64>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using 
ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<176, 64>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<176, 64>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<192, 64>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using 
ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<192, 64>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<208, 64>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<208, 64>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<224, 64>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<224, 64>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_64>; - using ThrID = Layout<_128>; - using 
ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<240, 64>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<240, 64>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<256, 64>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeA = GMMA::smem_desc; - 
using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_256,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<256, 64>; - using CLayout = GMMA::CLayout_64x256; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 8, 64>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_8,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 8, 64>; - using CLayout = GMMA::CLayout_64x8; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_64>; - 
using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 16, 64>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_16,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 16, 64>; - using CLayout = GMMA::CLayout_64x16; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 32, 64>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_32,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout 
= GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 32, 64>; - using CLayout = GMMA::CLayout_64x32; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 48, 64>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 48, 64>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_64>; - using ThrID = Layout<_128>; - using ALayout = 
GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 64, 64>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_64,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 64, 64>; - using CLayout = GMMA::CLayout_64x64; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 80, 64>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_80,_64>; - using ThrID = 
Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 80, 64>; - using CLayout = GMMA::CLayout_64x80; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 96, 64>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_96,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 96, 64>; - using CLayout = GMMA::CLayout_64x96; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_64>; - using ThrID = Layout<_128>; - using 
ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<112, 64>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_112,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<112, 64>; - using CLayout = GMMA::CLayout_64x112; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<128, 64>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_128,_64>; - using ThrID = Layout<_128>; - using ALayout 
= GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<128, 64>; - using CLayout = GMMA::CLayout_64x128; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<144, 64>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_144,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<144, 64>; - using CLayout = GMMA::CLayout_64x144; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = 
GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<160, 64>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_160,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<160, 64>; - using CLayout = GMMA::CLayout_64x160; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<176, 64>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE 
= sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_176,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<176, 64>; - using CLayout = GMMA::CLayout_64x176; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<192, 64>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_192,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<192, 64>; - using CLayout = GMMA::CLayout_64x192; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, 
uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<208, 64>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_208,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<208, 64>; - using CLayout = GMMA::CLayout_64x208; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<224, 64>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - 
using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_224,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<224, 64>; - using CLayout = GMMA::CLayout_64x224; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<240, 64>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_240,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<240, 64>; - using CLayout = GMMA::CLayout_64x240; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MMA_Traits> -{ - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; - + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_256,_64>; + using Shape_MNK = Shape<_64,_64,_16>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<256, 64>; - using CLayout = GMMA::CLayout_64x256; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = int32_t; - using ValTypeA = sparse_elem<2, uint8_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = uint8_t; - using ValTypeC = int32_t; + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_256,_64>; + using Shape_MNK = Shape<_64,_64,_16>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<256, 64>; - using CLayout = GMMA::CLayout_64x256; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -10232,23 +1528,23 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = 
half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_8,_64>; + using Shape_MNK = Shape<_64,_96,_16>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 8, 64>; - using CLayout = GMMA::CLayout_64x8; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -10256,22 +1552,22 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_8,_64>; + using Shape_MNK = Shape<_64,_96,_16>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 8, 64>; - using CLayout = GMMA::CLayout_64x8; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -10279,23 +1575,23 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; - 
using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_8,_64>; + using Shape_MNK = Shape<_64,_128,_16>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 8, 64>; - using CLayout = GMMA::CLayout_64x8; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -10303,22 +1599,22 @@ struct MMA_Traits -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_8,_64>; + using Shape_MNK = Shape<_64,_128,_16>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 8, 64>; - using CLayout = GMMA::CLayout_64x8; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -10326,23 +1622,23 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeD = float; + 
using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_16,_64>; + using Shape_MNK = Shape<_64,_192,_16>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 16, 64>; - using CLayout = GMMA::CLayout_64x16; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -10350,22 +1646,22 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_16,_64>; + using Shape_MNK = Shape<_64,_192,_16>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 16, 64>; - using CLayout = GMMA::CLayout_64x16; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -10373,23 +1669,23 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using 
ValTypeB = tfloat32_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_16,_64>; + using Shape_MNK = Shape<_64,_256,_16>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 16, 64>; - using CLayout = GMMA::CLayout_64x16; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -10397,678 +1693,656 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_16,_64>; + using Shape_MNK = Shape<_64,_256,_16>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 16, 64>; - using CLayout = GMMA::CLayout_64x16; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeA = 
GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_32,_64>; + using Shape_MNK = Shape<_64,_8,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 32, 64>; - using CLayout = GMMA::CLayout_64x32; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_32,_64>; + using Shape_MNK = Shape<_64,_8,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; + using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 32, 64>; - using CLayout = GMMA::CLayout_64x32; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = 
Shape<_64,_32,_64>; + using Shape_MNK = Shape<_64,_16,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 32, 64>; - using CLayout = GMMA::CLayout_64x32; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_32,_64>; + using Shape_MNK = Shape<_64,_16,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; + using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 32, 64>; - using CLayout = GMMA::CLayout_64x32; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_48,_64>; + using 
Shape_MNK = Shape<_64,_32,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 48, 64>; - using CLayout = GMMA::CLayout_64x48; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_48,_64>; + using Shape_MNK = Shape<_64,_32,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; + using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 48, 64>; - using CLayout = GMMA::CLayout_64x48; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using 
Shape_MNK = Shape<_64,_48,_64>; + using Shape_MNK = Shape<_64,_64,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 48, 64>; - using CLayout = GMMA::CLayout_64x48; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_48,_64>; + using Shape_MNK = Shape<_64,_64,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; + using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 48, 64>; - using CLayout = GMMA::CLayout_64x48; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = 
Shape<_64,_64,_64>; + using Shape_MNK = Shape<_64,_96,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 64, 64>; - using CLayout = GMMA::CLayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_64,_64>; + using Shape_MNK = Shape<_64,_96,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; + using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 64, 64>; - using CLayout = GMMA::CLayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_64,_64>; + using Shape_MNK = 
Shape<_64,_128,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 64, 64>; - using CLayout = GMMA::CLayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_64,_64>; + using Shape_MNK = Shape<_64,_128,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; + using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 64, 64>; - using CLayout = GMMA::CLayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_64>; + using Shape_MNK = Shape<_64,_192,_64>; using 
ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 80, 64>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_64>; + using Shape_MNK = Shape<_64,_192,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; + using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 80, 64>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_64>; + 
using Shape_MNK = Shape<_64,_256,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 80, 64>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_64>; + using Shape_MNK = Shape<_64,_256,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; + using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 80, 64>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_96,_64>; + using 
Shape_MNK = Shape<_64,_8,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; + using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 96, 64>; - using CLayout = GMMA::CLayout_64x96; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_96,_64>; + using Shape_MNK = Shape<_64,_8,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 96, 64>; - using CLayout = GMMA::CLayout_64x96; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_96,_64>; + using Shape_MNK = Shape<_64,_16,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; + using ALayout = 
GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 96, 64>; - using CLayout = GMMA::CLayout_64x96; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_96,_64>; + using Shape_MNK = Shape<_64,_16,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 96, 64>; - using CLayout = GMMA::CLayout_64x96; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_64>; + using Shape_MNK = Shape<_64,_32,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; + using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using 
BLayout = GMMA::ABLayout<112, 64>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_64>; + using Shape_MNK = Shape<_64,_32,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<112, 64>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_64>; + using Shape_MNK = Shape<_64,_64,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; + using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - 
using BLayout = GMMA::ABLayout<112, 64>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_64>; + using Shape_MNK = Shape<_64,_64,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<112, 64>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_64>; + using Shape_MNK = Shape<_64,_96,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; + using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<128, 64>; - using 
CLayout = GMMA::CLayout_64x128; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_64>; + using Shape_MNK = Shape<_64,_96,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<128, 64>; - using CLayout = GMMA::CLayout_64x128; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_128,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; + using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; using BLayout = GMMA::ABLayout<128, 64>; using CLayout = GMMA::CLayout_64x128; @@ -11078,14 +2352,14 @@ struct MMA_Traits -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using 
ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; @@ -11101,955 +2375,906 @@ struct MMA_Traits -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_64>; + using Shape_MNK = Shape<_64,_192,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; + using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<144, 64>; - using CLayout = GMMA::CLayout_64x144; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_64>; + using Shape_MNK = Shape<_64,_192,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = 
GMMA::ABLayout<144, 64>; - using CLayout = GMMA::CLayout_64x144; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_64>; + using Shape_MNK = Shape<_64,_256,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; + using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<144, 64>; - using CLayout = GMMA::CLayout_64x144; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_64>; + using Shape_MNK = Shape<_64,_256,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using 
BLayout = GMMA::ABLayout<144, 64>; - using CLayout = GMMA::CLayout_64x144; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_64>; + using Shape_MNK = Shape<_64,_8,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<160, 64>; - using CLayout = GMMA::CLayout_64x160; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_64>; + using Shape_MNK = Shape<_64,_8,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; + using ALayout = GMMA::ABLayout< 64, 64>; 
using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<160, 64>; - using CLayout = GMMA::CLayout_64x160; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_64>; + using Shape_MNK = Shape<_64,_16,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<160, 64>; - using CLayout = GMMA::CLayout_64x160; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_64>; + using Shape_MNK = Shape<_64,_16,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; + 
using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<160, 64>; - using CLayout = GMMA::CLayout_64x160; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_64>; + using Shape_MNK = Shape<_64,_32,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<176, 64>; - using CLayout = GMMA::CLayout_64x176; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_64>; + using Shape_MNK = Shape<_64,_32,_64>; using ThrID = 
Layout<_128>; - using ALayout = GMMA::ALayout_64x64; + using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<176, 64>; - using CLayout = GMMA::CLayout_64x176; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_64>; + using Shape_MNK = Shape<_64,_64,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<176, 64>; - using CLayout = GMMA::CLayout_64x176; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_64>; + using 
Shape_MNK = Shape<_64,_64,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; + using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<176, 64>; - using CLayout = GMMA::CLayout_64x176; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_192,_64>; + using Shape_MNK = Shape<_64,_96,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<192, 64>; - using CLayout = GMMA::CLayout_64x192; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_192,_64>; + using Shape_MNK = Shape<_64,_96,_64>; using ThrID = Layout<_128>; - 
using ALayout = GMMA::ALayout_64x64; + using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<192, 64>; - using CLayout = GMMA::CLayout_64x192; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_192,_64>; + using Shape_MNK = Shape<_64,_128,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<192, 64>; - using CLayout = GMMA::CLayout_64x192; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_192,_64>; + using Shape_MNK = Shape<_64,_128,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; + using ALayout = GMMA::ABLayout< 64, 
64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<192, 64>; - using CLayout = GMMA::CLayout_64x192; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_208,_64>; + using Shape_MNK = Shape<_64,_192,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<208, 64>; - using CLayout = GMMA::CLayout_64x208; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_208,_64>; + using Shape_MNK = Shape<_64,_192,_64>; using ThrID = Layout<_128>; - using ALayout = 
GMMA::ALayout_64x64; + using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<208, 64>; - using CLayout = GMMA::CLayout_64x208; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_208,_64>; + using Shape_MNK = Shape<_64,_256,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<208, 64>; - using CLayout = GMMA::CLayout_64x208; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_208,_64>; + using Shape_MNK = Shape<_64,_256,_64>; 
using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; + using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<208, 64>; - using CLayout = GMMA::CLayout_64x208; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_224,_64>; + using Shape_MNK = Shape<_64,_8,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; + using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<224, 64>; - using CLayout = GMMA::CLayout_64x224; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = 
Shape<_64,_224,_64>; + using Shape_MNK = Shape<_64,_8,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<224, 64>; - using CLayout = GMMA::CLayout_64x224; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_224,_64>; + using Shape_MNK = Shape<_64,_16,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; + using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<224, 64>; - using CLayout = GMMA::CLayout_64x224; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = 
Shape<_64,_224,_64>; + using Shape_MNK = Shape<_64,_16,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<224, 64>; - using CLayout = GMMA::CLayout_64x224; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_240,_64>; + using Shape_MNK = Shape<_64,_32,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; + using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<240, 64>; - using CLayout = GMMA::CLayout_64x240; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using 
Shape_MNK = Shape<_64,_240,_64>; + using Shape_MNK = Shape<_64,_32,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<240, 64>; - using CLayout = GMMA::CLayout_64x240; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_240,_64>; + using Shape_MNK = Shape<_64,_64,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; + using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<240, 64>; - using CLayout = GMMA::CLayout_64x240; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - 
using Shape_MNK = Shape<_64,_240,_64>; + using Shape_MNK = Shape<_64,_64,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<240, 64>; - using CLayout = GMMA::CLayout_64x240; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_256,_64>; + using Shape_MNK = Shape<_64,_96,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; + using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<256, 64>; - using CLayout = GMMA::CLayout_64x256; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_256,_64>; + using Shape_MNK = Shape<_64,_96,_64>; using ThrID = 
Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<256, 64>; - using CLayout = GMMA::CLayout_64x256; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_256,_64>; + using Shape_MNK = Shape<_64,_128,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; + using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<256, 64>; - using CLayout = GMMA::CLayout_64x256; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_256,_64>; + using Shape_MNK = Shape<_64,_128,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = 
GMMA::ABLayout<256, 64>; - using CLayout = GMMA::CLayout_64x256; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_8,_64>; + using Shape_MNK = Shape<_64,_192,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; + using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 8, 64>; - using CLayout = GMMA::CLayout_64x8; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_8,_64>; + using Shape_MNK = Shape<_64,_192,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 8, 64>; - using CLayout = GMMA::CLayout_64x8; + using BLayout = GMMA::ABLayout<192, 64>; + using 
CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_8,_64>; + using Shape_MNK = Shape<_64,_256,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; + using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 8, 64>; - using CLayout = GMMA::CLayout_64x8; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_8,_64>; + using Shape_MNK = Shape<_64,_256,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 8, 64>; - using CLayout = GMMA::CLayout_64x8; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; 
//////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_16,_64>; + using Shape_MNK = Shape<_64,_8,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 16, 64>; - using CLayout = GMMA::CLayout_64x16; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_16,_64>; + using Shape_MNK = Shape<_64,_8,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; + using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 16, 64>; - using CLayout = GMMA::CLayout_64x16; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; 
//////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; @@ -12066,20 +3291,21 @@ struct MMA_Traits -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_16,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; + using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; using BLayout = GMMA::ABLayout< 16, 64>; using CLayout = GMMA::CLayout_64x16; @@ -12089,14 +3315,14 @@ struct MMA_Traits -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; @@ -12113,20 +3339,21 @@ struct MMA_Traits -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = 
sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_32,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; + using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; using BLayout = GMMA::ABLayout< 32, 64>; using CLayout = GMMA::CLayout_64x32; @@ -12136,955 +3363,910 @@ struct MMA_Traits -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_32,_64>; + using Shape_MNK = Shape<_64,_64,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 32, 64>; - using CLayout = GMMA::CLayout_64x32; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_32,_64>; + using Shape_MNK = Shape<_64,_64,_64>; using ThrID = Layout<_128>; - using 
ALayout = GMMA::ALayout_64x64; + using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 32, 64>; - using CLayout = GMMA::CLayout_64x32; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_48,_64>; + using Shape_MNK = Shape<_64,_96,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 48, 64>; - using CLayout = GMMA::CLayout_64x48; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_48,_64>; + using Shape_MNK = Shape<_64,_96,_64>; 
using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; + using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 48, 64>; - using CLayout = GMMA::CLayout_64x48; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_48,_64>; + using Shape_MNK = Shape<_64,_128,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 48, 64>; - using CLayout = GMMA::CLayout_64x48; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_48,_64>; 
+ using Shape_MNK = Shape<_64,_128,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; + using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 48, 64>; - using CLayout = GMMA::CLayout_64x48; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_64,_64>; + using Shape_MNK = Shape<_64,_192,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 64, 64>; - using CLayout = GMMA::CLayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_64,_64>; + using Shape_MNK = Shape<_64,_192,_64>; using ThrID = 
Layout<_128>; - using ALayout = GMMA::ALayout_64x64; + using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 64, 64>; - using CLayout = GMMA::CLayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_64,_64>; + using Shape_MNK = Shape<_64,_256,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 64, 64>; - using CLayout = GMMA::CLayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_64,_64>; + using Shape_MNK = Shape<_64,_256,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; + using ALayout = 
GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 64, 64>; - using CLayout = GMMA::CLayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_64>; + using Shape_MNK = Shape<_64,_8,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; + using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 80, 64>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_64>; + using Shape_MNK = Shape<_64,_8,_64>; using ThrID = Layout<_128>; using ALayout = 
GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 80, 64>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_64>; + using Shape_MNK = Shape<_64,_16,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; + using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 80, 64>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_64>; + using Shape_MNK = Shape<_64,_16,_64>; using ThrID = Layout<_128>; using ALayout = 
GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 80, 64>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_96,_64>; + using Shape_MNK = Shape<_64,_32,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; + using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 96, 64>; - using CLayout = GMMA::CLayout_64x96; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_96,_64>; + using Shape_MNK = Shape<_64,_32,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 96, 64>; - using 
CLayout = GMMA::CLayout_64x96; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_96,_64>; + using Shape_MNK = Shape<_64,_64,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; + using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 96, 64>; - using CLayout = GMMA::CLayout_64x96; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_96,_64>; + using Shape_MNK = Shape<_64,_64,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 96, 64>; - using CLayout = GMMA::CLayout_64x96; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; 
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_64>; + using Shape_MNK = Shape<_64,_96,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; + using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<112, 64>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_64>; + using Shape_MNK = Shape<_64,_96,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<112, 64>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; 
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_64>; + using Shape_MNK = Shape<_64,_128,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; + using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<112, 64>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_64>; + using Shape_MNK = Shape<_64,_128,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<112, 64>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = 
GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_64>; + using Shape_MNK = Shape<_64,_192,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; + using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<128, 64>; - using CLayout = GMMA::CLayout_64x128; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_64>; + using Shape_MNK = Shape<_64,_192,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<128, 64>; - using CLayout = GMMA::CLayout_64x128; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; 
//////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = int8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_64>; + using Shape_MNK = Shape<_64,_256,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; + using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<128, 64>; - using CLayout = GMMA::CLayout_64x128; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = int8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_64>; + using Shape_MNK = Shape<_64,_256,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<128, 64>; - using CLayout = GMMA::CLayout_64x128; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// - -#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> + +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_64>; + using Shape_MNK = Shape<_64,_8,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<144, 64>; - using CLayout = GMMA::CLayout_64x144; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_64>; + using Shape_MNK = Shape<_64,_8,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; + using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<144, 64>; - using CLayout = GMMA::CLayout_64x144; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif 
//////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_64>; + using Shape_MNK = Shape<_64,_16,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<144, 64>; - using CLayout = GMMA::CLayout_64x144; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_64>; + using Shape_MNK = Shape<_64,_16,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; + using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<144, 64>; - using CLayout = GMMA::CLayout_64x144; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut 
accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_64>; + using Shape_MNK = Shape<_64,_32,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<160, 64>; - using CLayout = GMMA::CLayout_64x160; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_64>; + using Shape_MNK = Shape<_64,_32,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; + using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<160, 64>; - using CLayout = GMMA::CLayout_64x160; + using BLayout = GMMA::ABLayout< 32, 64>; + 
using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_64>; + using Shape_MNK = Shape<_64,_64,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<160, 64>; - using CLayout = GMMA::CLayout_64x160; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_64>; + using Shape_MNK = Shape<_64,_64,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; + using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<160, 64>; - using CLayout = GMMA::CLayout_64x160; + 
using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_64>; + using Shape_MNK = Shape<_64,_96,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<176, 64>; - using CLayout = GMMA::CLayout_64x176; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_64>; + using Shape_MNK = Shape<_64,_96,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; + using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<176, 
64>; - using CLayout = GMMA::CLayout_64x176; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_64>; + using Shape_MNK = Shape<_64,_128,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<176, 64>; - using CLayout = GMMA::CLayout_64x176; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_64>; + using Shape_MNK = Shape<_64,_128,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; + using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = 
GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<176, 64>; - using CLayout = GMMA::CLayout_64x176; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; @@ -13101,20 +4283,21 @@ struct MMA_Traits -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_192,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; + using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; using BLayout = GMMA::ABLayout<192, 64>; using CLayout = GMMA::CLayout_64x192; @@ -13124,421 +4307,390 @@ struct MMA_Traits -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using 
Shape_MNK = Shape<_64,_192,_64>; + using Shape_MNK = Shape<_64,_256,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<192, 64>; - using CLayout = GMMA::CLayout_64x192; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_192,_64>; + using Shape_MNK = Shape<_64,_256,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; + using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<192, 64>; - using CLayout = GMMA::CLayout_64x192; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = 
Shape<_64,_208,_64>; + using Shape_MNK = Shape<_64,_8,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; + using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<208, 64>; - using CLayout = GMMA::CLayout_64x208; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_208,_64>; + using Shape_MNK = Shape<_64,_8,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<208, 64>; - using CLayout = GMMA::CLayout_64x208; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK 
= Shape<_64,_208,_64>; + using Shape_MNK = Shape<_64,_16,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; + using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<208, 64>; - using CLayout = GMMA::CLayout_64x208; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_208,_64>; + using Shape_MNK = Shape<_64,_16,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<208, 64>; - using CLayout = GMMA::CLayout_64x208; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using 
Shape_MNK = Shape<_64,_224,_64>; + using Shape_MNK = Shape<_64,_32,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; + using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<224, 64>; - using CLayout = GMMA::CLayout_64x224; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_224,_64>; + using Shape_MNK = Shape<_64,_32,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<224, 64>; - using CLayout = GMMA::CLayout_64x224; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; 
- using Shape_MNK = Shape<_64,_224,_64>; + using Shape_MNK = Shape<_64,_64,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; + using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<224, 64>; - using CLayout = GMMA::CLayout_64x224; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_224,_64>; + using Shape_MNK = Shape<_64,_64,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<224, 64>; - using CLayout = GMMA::CLayout_64x224; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = 
GMMA::smem_desc; - using Shape_MNK = Shape<_64,_240,_64>; + using Shape_MNK = Shape<_64,_96,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; + using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<240, 64>; - using CLayout = GMMA::CLayout_64x240; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_240,_64>; + using Shape_MNK = Shape<_64,_96,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<240, 64>; - using CLayout = GMMA::CLayout_64x240; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using 
FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_240,_64>; + using Shape_MNK = Shape<_64,_128,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; + using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<240, 64>; - using CLayout = GMMA::CLayout_64x240; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_240,_64>; + using Shape_MNK = Shape<_64,_128,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<240, 64>; - using CLayout = GMMA::CLayout_64x240; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using 
Shape_MNK = Shape<_64,_256,_64>; + using Shape_MNK = Shape<_64,_192,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; + using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<256, 64>; - using CLayout = GMMA::CLayout_64x256; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_256,_64>; + using Shape_MNK = Shape<_64,_192,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<256, 64>; - using CLayout = GMMA::CLayout_64x256; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; - using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; using Shape_MNK = Shape<_64,_256,_64>; using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; + 
using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; using BLayout = GMMA::ABLayout<256, 64>; using CLayout = GMMA::CLayout_64x256; @@ -13548,14 +4700,14 @@ struct MMA_Traits -struct MMA_Traits> +template +struct MMA_Traits> { - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; using FrgTypeB = GMMA::smem_desc; @@ -13572,10 +4724,10 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = float_e4m3_t; using ValTypeC = half_t; @@ -13596,10 +4748,10 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = float_e4m3_t; using ValTypeC = half_t; @@ -13619,10 +4771,10 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = float_e4m3_t; using ValTypeC = float; @@ -13643,10 +4795,10 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = float_e4m3_t; using ValTypeC = float; @@ -13666,10 +4818,10 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, 
float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = float_e4m3_t; using ValTypeC = half_t; @@ -13690,10 +4842,10 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = float_e4m3_t; using ValTypeC = half_t; @@ -13713,10 +4865,10 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = float_e4m3_t; using ValTypeC = float; @@ -13737,10 +4889,10 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = float_e4m3_t; using ValTypeC = float; @@ -13760,10 +4912,10 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = float_e4m3_t; using ValTypeC = half_t; @@ -13784,10 +4936,10 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = float_e4m3_t; using ValTypeC = half_t; @@ -13807,10 +4959,10 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = float_e4m3_t; using ValTypeC = float; @@ -13831,10 +4983,10 @@ struct MMA_Traits -struct 
MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = float_e4m3_t; using ValTypeC = float; @@ -13853,113 +5005,11 @@ struct MMA_Traits -struct MMA_Traits> -{ - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 48, 64>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = half_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 48, 64>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using 
FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 48, 64>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 48, 64>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = float_e4m3_t; using ValTypeC = half_t; @@ -13980,10 +5030,10 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = float_e4m3_t; using ValTypeC = half_t; @@ -14003,10 +5053,10 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = float_e4m3_t; 
using ValTypeC = float; @@ -14027,10 +5077,10 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = float_e4m3_t; using ValTypeC = float; @@ -14049,12 +5099,11 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = float_e4m3_t; using ValTypeC = half_t; @@ -14062,50 +5111,46 @@ struct MMA_Traits; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_64>; + using Shape_MNK = Shape<_64,_96,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 80, 64>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_64>; + using Shape_MNK = Shape<_64,_96,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 80, 64>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif 
//////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = float_e4m3_t; using ValTypeC = float; @@ -14113,49 +5158,46 @@ struct MMA_Traits; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_64>; + using Shape_MNK = Shape<_64,_96,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 80, 64>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_64>; + using Shape_MNK = Shape<_64,_96,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 80, 64>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, 
float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = float_e4m3_t; using ValTypeC = half_t; @@ -14163,12 +5205,12 @@ struct MMA_Traits; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_96,_64>; + using Shape_MNK = Shape<_64,_128,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 96, 64>; - using CLayout = GMMA::CLayout_64x96; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -14176,22 +5218,22 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_96,_64>; + using Shape_MNK = Shape<_64,_128,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 96, 64>; - using CLayout = GMMA::CLayout_64x96; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -14199,10 +5241,10 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = float_e4m3_t; using ValTypeC = float; @@ -14210,12 +5252,12 @@ struct MMA_Traits; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_96,_64>; + using Shape_MNK = Shape<_64,_128,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 96, 64>; - 
using CLayout = GMMA::CLayout_64x96; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -14223,34 +5265,33 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_96,_64>; + using Shape_MNK = Shape<_64,_128,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 96, 64>; - using CLayout = GMMA::CLayout_64x96; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = float_e4m3_t; using ValTypeC = half_t; @@ -14258,50 +5299,46 @@ struct MMA_Traits; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_64>; + using Shape_MNK = Shape<_64,_192,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<112, 64>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct 
MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_64>; + using Shape_MNK = Shape<_64,_192,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<112, 64>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = float_e4m3_t; using ValTypeC = float; @@ -14309,49 +5346,46 @@ struct MMA_Traits; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_64>; + using Shape_MNK = Shape<_64,_192,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<112, 64>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB 
= float_e4m3_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_64>; + using Shape_MNK = Shape<_64,_192,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<112, 64>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = float_e4m3_t; using ValTypeC = half_t; @@ -14359,12 +5393,12 @@ struct MMA_Traits; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_64>; + using Shape_MNK = Shape<_64,_256,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<128, 64>; - using CLayout = GMMA::CLayout_64x128; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -14372,22 +5406,22 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_64>; + using Shape_MNK = Shape<_64,_256,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<128, 64>; - using CLayout = GMMA::CLayout_64x128; + using BLayout = GMMA::ABLayout<256, 64>; + 
using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -14395,10 +5429,10 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = float_e4m3_t; using ValTypeC = float; @@ -14406,12 +5440,12 @@ struct MMA_Traits; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_64>; + using Shape_MNK = Shape<_64,_256,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<128, 64>; - using CLayout = GMMA::CLayout_64x128; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -14419,352 +5453,328 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_64>; + using Shape_MNK = Shape<_64,_256,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<128, 64>; - using CLayout = GMMA::CLayout_64x128; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using 
ValTypeB = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_64>; + using Shape_MNK = Shape<_64,_8,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<144, 64>; - using CLayout = GMMA::CLayout_64x144; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_64>; + using Shape_MNK = Shape<_64,_8,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<144, 64>; - using CLayout = GMMA::CLayout_64x144; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using 
Shape_MNK = Shape<_64,_144,_64>; + using Shape_MNK = Shape<_64,_8,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<144, 64>; - using CLayout = GMMA::CLayout_64x144; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_64>; + using Shape_MNK = Shape<_64,_8,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<144, 64>; - using CLayout = GMMA::CLayout_64x144; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_64>; + using Shape_MNK = Shape<_64,_16,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = 
GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<160, 64>; - using CLayout = GMMA::CLayout_64x160; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_64>; + using Shape_MNK = Shape<_64,_16,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<160, 64>; - using CLayout = GMMA::CLayout_64x160; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_64>; + using Shape_MNK = Shape<_64,_16,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<160, 64>; - using CLayout = GMMA::CLayout_64x160; + using BLayout = GMMA::ABLayout< 16, 64>; + using 
CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_64>; + using Shape_MNK = Shape<_64,_16,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<160, 64>; - using CLayout = GMMA::CLayout_64x160; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_64>; + using Shape_MNK = Shape<_64,_32,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<176, 64>; - using CLayout = GMMA::CLayout_64x176; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif 
//////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_64>; + using Shape_MNK = Shape<_64,_32,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<176, 64>; - using CLayout = GMMA::CLayout_64x176; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_64>; + using Shape_MNK = Shape<_64,_32,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<176, 64>; - using CLayout = GMMA::CLayout_64x176; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) 
template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_64>; + using Shape_MNK = Shape<_64,_32,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<176, 64>; - using CLayout = GMMA::CLayout_64x176; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_192,_64>; + using Shape_MNK = Shape<_64,_64,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<192, 64>; - using CLayout = GMMA::CLayout_64x192; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -14772,22 +5782,22 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = half_t; using FrgTypeB = 
GMMA::smem_desc; - using Shape_MNK = Shape<_64,_192,_64>; + using Shape_MNK = Shape<_64,_64,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<192, 64>; - using CLayout = GMMA::CLayout_64x192; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -14795,23 +5805,23 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_192,_64>; + using Shape_MNK = Shape<_64,_64,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<192, 64>; - using CLayout = GMMA::CLayout_64x192; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -14819,341 +5829,317 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_192,_64>; + using Shape_MNK = Shape<_64,_64,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<192, 64>; - using CLayout = GMMA::CLayout_64x192; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_208,_64>; + using Shape_MNK = Shape<_64,_96,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<208, 64>; - using CLayout = GMMA::CLayout_64x208; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_208,_64>; + using Shape_MNK = Shape<_64,_96,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<208, 64>; - using CLayout = GMMA::CLayout_64x208; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_208,_64>; + using Shape_MNK = Shape<_64,_96,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<208, 64>; - using CLayout = GMMA::CLayout_64x208; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_208,_64>; + using Shape_MNK = Shape<_64,_96,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<208, 64>; - using CLayout = GMMA::CLayout_64x208; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, 
float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_224,_64>; + using Shape_MNK = Shape<_64,_128,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<224, 64>; - using CLayout = GMMA::CLayout_64x224; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_224,_64>; + using Shape_MNK = Shape<_64,_128,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<224, 64>; - using CLayout = GMMA::CLayout_64x224; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using 
ValTypeB = float_e5m2_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_224,_64>; + using Shape_MNK = Shape<_64,_128,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<224, 64>; - using CLayout = GMMA::CLayout_64x224; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_224,_64>; + using Shape_MNK = Shape<_64,_128,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<224, 64>; - using CLayout = GMMA::CLayout_64x224; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = 
Shape<_64,_240,_64>; + using Shape_MNK = Shape<_64,_192,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<240, 64>; - using CLayout = GMMA::CLayout_64x240; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_240,_64>; + using Shape_MNK = Shape<_64,_192,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<240, 64>; - using CLayout = GMMA::CLayout_64x240; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_240,_64>; + using Shape_MNK = Shape<_64,_192,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = 
GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<240, 64>; - using CLayout = GMMA::CLayout_64x240; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_240,_64>; + using Shape_MNK = Shape<_64,_192,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<240, 64>; - using CLayout = GMMA::CLayout_64x240; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; @@ -15172,12 +6158,12 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; @@ -15195,12 +6181,12 @@ struct 
MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; @@ -15219,12 +6205,12 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeA = sparse_elem<2, float_e4m3_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e4m3_t; + using ValTypeB = float_e5m2_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; @@ -15242,12 +6228,12 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, float_e5m2_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; @@ -15266,12 +6252,12 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, float_e5m2_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; @@ -15289,12 +6275,12 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = sparse_elem<2, float_e5m2_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; @@ -15313,12 +6299,12 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = sparse_elem<2, float_e5m2_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeB = 
GMMA::smem_desc; @@ -15336,12 +6322,12 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, float_e5m2_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; @@ -15360,12 +6346,12 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, float_e5m2_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; @@ -15383,12 +6369,12 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = sparse_elem<2, float_e5m2_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; @@ -15407,12 +6393,12 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = sparse_elem<2, float_e5m2_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; @@ -15430,12 +6416,12 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, float_e5m2_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; @@ -15454,12 +6440,12 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, float_e5m2_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; @@ -15477,12 +6463,12 @@ struct MMA_Traits 
-struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = sparse_elem<2, float_e5m2_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; @@ -15501,12 +6487,12 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = sparse_elem<2, float_e5m2_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; @@ -15523,115 +6509,13 @@ struct MMA_Traits -struct MMA_Traits> -{ - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 48, 64>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = half_t; - using ValTypeA = sparse_elem<2, float_e5m2_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = half_t; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 48, 64>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; - - using FrgTypeA = GMMA::smem_desc; - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 48, 64>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) -template -struct MMA_Traits> -{ - using ValTypeD = float; - using ValTypeA = sparse_elem<2, float_e5m2_t>; - using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; - using ValTypeC = float; - - using FrgTypeB = GMMA::smem_desc; - - using Shape_MNK = Shape<_64,_48,_64>; - using ThrID = Layout<_128>; - using ALayout = GMMA::ALayout_64x64; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 48, 64>; - using CLayout = GMMA::CLayout_64x48; - - GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, float_e5m2_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; @@ -15650,12 +6534,12 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, 
float_e5m2_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; @@ -15673,12 +6557,12 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = sparse_elem<2, float_e5m2_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; @@ -15697,12 +6581,12 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = sparse_elem<2, float_e5m2_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; @@ -15719,126 +6603,118 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, float_e5m2_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_64>; + using Shape_MNK = Shape<_64,_96,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 80, 64>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, float_e5m2_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = 
half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_64>; + using Shape_MNK = Shape<_64,_96,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 80, 64>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = sparse_elem<2, float_e5m2_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_64>; + using Shape_MNK = Shape<_64,_96,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 80, 64>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = sparse_elem<2, float_e5m2_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_80,_64>; + using Shape_MNK = Shape<_64,_96,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 80, 
64>; - using CLayout = GMMA::CLayout_64x80; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, float_e5m2_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_96,_64>; + using Shape_MNK = Shape<_64,_128,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; - using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 96, 64>; - using CLayout = GMMA::CLayout_64x96; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -15846,22 +6722,22 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, float_e5m2_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_96,_64>; + using Shape_MNK = Shape<_64,_128,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 96, 64>; - using CLayout = GMMA::CLayout_64x96; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -15869,23 +6745,23 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = sparse_elem<2, float_e5m2_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = 
float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_96,_64>; + using Shape_MNK = Shape<_64,_128,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 96, 64>; - using CLayout = GMMA::CLayout_64x96; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -15893,148 +6769,140 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = sparse_elem<2, float_e5m2_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_96,_64>; + using Shape_MNK = Shape<_64,_128,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout< 96, 64>; - using CLayout = GMMA::CLayout_64x96; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, float_e5m2_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_64>; + using Shape_MNK = Shape<_64,_192,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<112, 64>; - using CLayout = 
GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, float_e5m2_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_64>; + using Shape_MNK = Shape<_64,_192,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<112, 64>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = sparse_elem<2, float_e5m2_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_64>; + using Shape_MNK = Shape<_64,_192,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<112, 64>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif 
//////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = sparse_elem<2, float_e5m2_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_112,_64>; + using Shape_MNK = Shape<_64,_192,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<112, 64>; - using CLayout = GMMA::CLayout_64x112; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, float_e5m2_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_64>; + using Shape_MNK = Shape<_64,_256,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<128, 64>; - using CLayout = GMMA::CLayout_64x128; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -16042,22 +6910,22 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, float_e5m2_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = half_t; using FrgTypeB = 
GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_64>; + using Shape_MNK = Shape<_64,_256,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<128, 64>; - using CLayout = GMMA::CLayout_64x128; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -16065,23 +6933,23 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = sparse_elem<2, float_e5m2_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeA = GMMA::smem_desc; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_64>; + using Shape_MNK = Shape<_64,_256,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<128, 64>; - using CLayout = GMMA::CLayout_64x128; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -16089,31 +6957,30 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = sparse_elem<2, float_e5m2_t>; using ValTypeE = sparse_elem<8, uint8_t>; - using ValTypeB = float_e5m2_t; + using ValTypeB = float_e4m3_t; using ValTypeC = float; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_128,_64>; + using Shape_MNK = Shape<_64,_256,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<128, 64>; - using CLayout = GMMA::CLayout_64x128; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; 
//////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, float_e5m2_t>; @@ -16124,22 +6991,20 @@ struct MMA_Traits; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_64>; + using Shape_MNK = Shape<_64,_8,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<144, 64>; - using CLayout = GMMA::CLayout_64x144; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, float_e5m2_t>; @@ -16149,22 +7014,20 @@ struct MMA_Traits; - using Shape_MNK = Shape<_64,_144,_64>; + using Shape_MNK = Shape<_64,_8,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<144, 64>; - using CLayout = GMMA::CLayout_64x144; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = sparse_elem<2, float_e5m2_t>; @@ -16175,22 +7038,20 @@ struct MMA_Traits; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_144,_64>; + using Shape_MNK = Shape<_64,_8,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using 
ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<144, 64>; - using CLayout = GMMA::CLayout_64x144; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = sparse_elem<2, float_e5m2_t>; @@ -16200,22 +7061,20 @@ struct MMA_Traits; - using Shape_MNK = Shape<_64,_144,_64>; + using Shape_MNK = Shape<_64,_8,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<144, 64>; - using CLayout = GMMA::CLayout_64x144; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, float_e5m2_t>; @@ -16226,22 +7085,20 @@ struct MMA_Traits; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_64>; + using Shape_MNK = Shape<_64,_16,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<160, 64>; - using CLayout = GMMA::CLayout_64x160; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = 
sparse_elem<2, float_e5m2_t>; @@ -16251,22 +7108,20 @@ struct MMA_Traits; - using Shape_MNK = Shape<_64,_160,_64>; + using Shape_MNK = Shape<_64,_16,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<160, 64>; - using CLayout = GMMA::CLayout_64x160; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = sparse_elem<2, float_e5m2_t>; @@ -16277,22 +7132,20 @@ struct MMA_Traits; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_160,_64>; + using Shape_MNK = Shape<_64,_16,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<160, 64>; - using CLayout = GMMA::CLayout_64x160; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = sparse_elem<2, float_e5m2_t>; @@ -16302,22 +7155,20 @@ struct MMA_Traits; - using Shape_MNK = Shape<_64,_160,_64>; + using Shape_MNK = Shape<_64,_16,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<160, 64>; - using CLayout = GMMA::CLayout_64x160; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif 
//////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, float_e5m2_t>; @@ -16328,22 +7179,20 @@ struct MMA_Traits; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_64>; + using Shape_MNK = Shape<_64,_32,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<176, 64>; - using CLayout = GMMA::CLayout_64x176; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, float_e5m2_t>; @@ -16353,22 +7202,20 @@ struct MMA_Traits; - using Shape_MNK = Shape<_64,_176,_64>; + using Shape_MNK = Shape<_64,_32,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<176, 64>; - using CLayout = GMMA::CLayout_64x176; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = sparse_elem<2, float_e5m2_t>; @@ -16379,22 +7226,20 @@ struct MMA_Traits; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_176,_64>; + using Shape_MNK = Shape<_64,_32,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; 
using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<176, 64>; - using CLayout = GMMA::CLayout_64x176; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = sparse_elem<2, float_e5m2_t>; @@ -16404,21 +7249,20 @@ struct MMA_Traits; - using Shape_MNK = Shape<_64,_176,_64>; + using Shape_MNK = Shape<_64,_32,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<176, 64>; - using CLayout = GMMA::CLayout_64x176; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, float_e5m2_t>; @@ -16429,12 +7273,12 @@ struct MMA_Traits; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_192,_64>; + using Shape_MNK = Shape<_64,_64,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<192, 64>; - using CLayout = GMMA::CLayout_64x192; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -16442,7 +7286,7 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, float_e5m2_t>; @@ -16452,12 +7296,12 @@ struct MMA_Traits; - using Shape_MNK = Shape<_64,_192,_64>; + using Shape_MNK = Shape<_64,_64,_64>; using ThrID = 
Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<192, 64>; - using CLayout = GMMA::CLayout_64x192; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -16465,7 +7309,7 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = sparse_elem<2, float_e5m2_t>; @@ -16476,12 +7320,12 @@ struct MMA_Traits; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_192,_64>; + using Shape_MNK = Shape<_64,_64,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<192, 64>; - using CLayout = GMMA::CLayout_64x192; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; @@ -16489,7 +7333,7 @@ struct MMA_Traits -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = sparse_elem<2, float_e5m2_t>; @@ -16499,21 +7343,20 @@ struct MMA_Traits; - using Shape_MNK = Shape<_64,_192,_64>; + using Shape_MNK = Shape<_64,_64,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<192, 64>; - using CLayout = GMMA::CLayout_64x192; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, float_e5m2_t>; @@ -16524,22 +7367,20 @@ struct MMA_Traits; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_208,_64>; + using Shape_MNK = Shape<_64,_96,_64>; using ThrID = 
Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<208, 64>; - using CLayout = GMMA::CLayout_64x208; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, float_e5m2_t>; @@ -16549,22 +7390,20 @@ struct MMA_Traits; - using Shape_MNK = Shape<_64,_208,_64>; + using Shape_MNK = Shape<_64,_96,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<208, 64>; - using CLayout = GMMA::CLayout_64x208; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = sparse_elem<2, float_e5m2_t>; @@ -16575,22 +7414,20 @@ struct MMA_Traits; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_208,_64>; + using Shape_MNK = Shape<_64,_96,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<208, 64>; - using CLayout = GMMA::CLayout_64x208; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> 
+struct MMA_Traits> { using ValTypeD = float; using ValTypeA = sparse_elem<2, float_e5m2_t>; @@ -16600,22 +7437,20 @@ struct MMA_Traits; - using Shape_MNK = Shape<_64,_208,_64>; + using Shape_MNK = Shape<_64,_96,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<208, 64>; - using CLayout = GMMA::CLayout_64x208; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, float_e5m2_t>; @@ -16626,22 +7461,20 @@ struct MMA_Traits; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_224,_64>; + using Shape_MNK = Shape<_64,_128,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<224, 64>; - using CLayout = GMMA::CLayout_64x224; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, float_e5m2_t>; @@ -16651,22 +7484,20 @@ struct MMA_Traits; - using Shape_MNK = Shape<_64,_224,_64>; + using Shape_MNK = Shape<_64,_128,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<224, 64>; - using CLayout = GMMA::CLayout_64x224; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut 
accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = sparse_elem<2, float_e5m2_t>; @@ -16677,22 +7508,20 @@ struct MMA_Traits; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_224,_64>; + using Shape_MNK = Shape<_64,_128,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<224, 64>; - using CLayout = GMMA::CLayout_64x224; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = sparse_elem<2, float_e5m2_t>; @@ -16702,22 +7531,20 @@ struct MMA_Traits; - using Shape_MNK = Shape<_64,_224,_64>; + using Shape_MNK = Shape<_64,_128,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<224, 64>; - using CLayout = GMMA::CLayout_64x224; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, float_e5m2_t>; @@ -16728,22 +7555,20 @@ struct MMA_Traits; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_240,_64>; + using Shape_MNK = Shape<_64,_192,_64>; using ThrID = 
Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<240, 64>; - using CLayout = GMMA::CLayout_64x240; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = half_t; using ValTypeA = sparse_elem<2, float_e5m2_t>; @@ -16753,22 +7578,20 @@ struct MMA_Traits; - using Shape_MNK = Shape<_64,_240,_64>; + using Shape_MNK = Shape<_64,_192,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<240, 64>; - using CLayout = GMMA::CLayout_64x240; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> +struct MMA_Traits> { using ValTypeD = float; using ValTypeA = sparse_elem<2, float_e5m2_t>; @@ -16779,22 +7602,20 @@ struct MMA_Traits; using FrgTypeB = GMMA::smem_desc; - using Shape_MNK = Shape<_64,_240,_64>; + using Shape_MNK = Shape<_64,_192,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ABLayout< 64, 64>; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<240, 64>; - using CLayout = GMMA::CLayout_64x240; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) template -struct MMA_Traits> 
+struct MMA_Traits> { using ValTypeD = float; using ValTypeA = sparse_elem<2, float_e5m2_t>; @@ -16804,16 +7625,15 @@ struct MMA_Traits; - using Shape_MNK = Shape<_64,_240,_64>; + using Shape_MNK = Shape<_64,_192,_64>; using ThrID = Layout<_128>; using ALayout = GMMA::ALayout_64x64; using ELayout = GMMA::ELayout_64x64; - using BLayout = GMMA::ABLayout<240, 64>; - using CLayout = GMMA::CLayout_64x240; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; }; -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -16909,7 +7729,10 @@ struct MMA_Traits +#include + +namespace cute { + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; 
+}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; 
+}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 24, 16>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 24, 16>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 40, 16>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 40, 16>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 48, 16>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; 
+}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 48, 16>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 56, 16>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 56, 16>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 72, 16>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 72, 16>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 80, 16>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; 
+}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 80, 16>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 88, 16>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 88, 16>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<104, 16>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<104, 16>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<120, 16>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<120, 16>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<136, 16>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<136, 16>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<152, 16>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<152, 16>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<168, 16>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<168, 16>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<184, 16>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<184, 16>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<200, 16>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<200, 16>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<216, 16>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<216, 16>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<232, 16>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<232, 16>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<248, 16>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<248, 16>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + 
+ GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + 
using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using 
BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using 
ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using 
ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; 
+ using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using 
ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using 
CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut 
accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = 
GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = 
GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout 
= GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = 
GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using 
ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout 
= GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = 
GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + 
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut 
accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = 
GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = 
GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout 
= GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = 
GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using 
ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + 
using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = 
GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = 
GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; 
+ +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut 
accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = 
GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = 
GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using 
ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using 
ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = 
Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = 
GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = 
GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + 
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; 
+}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; 
+}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; 
+}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; 
+}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut 
accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + 
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut 
accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = 
GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = 
GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 
64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using 
CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = 
GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = 
GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = 
GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = 
GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 
64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + 
using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = 
GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = 
GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = 
Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = 
Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = 
Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = 
Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; 
+ + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using 
Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + 
+ using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; 
+ + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + 
+ using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; 
+ + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + 
+ using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; 
+ using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using 
FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = 
GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = 
half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + 
using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + 
using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC 
= float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = 
float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + 
using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using 
ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = 
float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, 
uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + 
using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = 
sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, 
uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using 
ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = 
sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, 
float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; 
+ using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = 
sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, 
float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using 
ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = 
sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; 
+ using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = 
sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; 
+ using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = 
sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; 
+ using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = 
sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; 
+ using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA 
= sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = 
half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using 
ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using 
ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = 
float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct 
MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + 
using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template 
+struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct 
MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + 
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut 
accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = 
GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = 
GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 
64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using 
CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = 
GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = 
GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = 
GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = 
GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 
64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout 
= GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 
64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = 
GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 
64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout 
= GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 
64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = 
GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 
64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using 
ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = 
GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = 
GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = 
Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = 
Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = 
Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = 
Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; 
+ + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using 
Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = 
GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = 
half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + 
using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + 
using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC 
= float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = 
float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + 
using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using 
ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = 
float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, 
uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + 
using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = 
sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/include/cute/numeric/integral_constant.hpp b/include/cute/numeric/integral_constant.hpp index e447103b9..3a8d036ee 100644 --- a/include/cute/numeric/integral_constant.hpp +++ b/include/cute/numeric/integral_constant.hpp @@ -146,19 +146,33 @@ using _12 = Int<12>; using _16 = Int<16>; using _24 = Int<24>; using _32 = Int<32>; +using _40 = Int<40>; using _48 = Int<48>; +using _56 = Int<56>; using _64 = Int<64>; +using _72 = Int<72>; using _80 = Int<80>; +using _88 = Int<88>; using _96 = Int<96>; +using _104 = Int<104>; using _112 = Int<112>; +using _120 = Int<120>; using _128 = Int<128>; +using _136 = Int<136>; using _144 = Int<144>; +using _152 = Int<152>; using _160 = Int<160>; +using _168 = Int<168>; using _176 = Int<176>; +using _184 = Int<184>; using _192 = Int<192>; +using _200 = Int<200>; using _208 = Int<208>; +using _216 = Int<216>; using _224 = Int<224>; +using _232 = Int<232>; using _240 = Int<240>; +using _248 = Int<248>; using _256 = Int<256>; using _384 = Int<384>; using _512 = Int<512>;